feat(router): use number of tokens in batch as input for dynamic batching (#226)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
parent
98a3e0d135
commit
ebc74d5666
|
@ -39,8 +39,12 @@ struct Args {
|
|||
max_input_length: usize,
|
||||
#[clap(default_value = "1512", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "32000", long, env)]
|
||||
max_batch_total_tokens: u32,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
|
@ -93,6 +97,8 @@ fn main() -> ExitCode {
|
|||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_batch_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
shard_uds_path,
|
||||
|
@ -380,8 +386,8 @@ fn main() -> ExitCode {
|
|||
max_input_length.to_string(),
|
||||
"--max-total-tokens".to_string(),
|
||||
max_total_tokens.to_string(),
|
||||
"--max-batch-size".to_string(),
|
||||
max_batch_size.to_string(),
|
||||
"--waiting-served-ratio".to_string(),
|
||||
waiting_served_ratio.to_string(),
|
||||
"--max-waiting-tokens".to_string(),
|
||||
max_waiting_tokens.to_string(),
|
||||
"--port".to_string(),
|
||||
|
@ -392,6 +398,15 @@ fn main() -> ExitCode {
|
|||
model_id,
|
||||
];
|
||||
|
||||
// Deprecate max_batch_size
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
argv.push("--max-batch-size".to_string());
|
||||
argv.push(max_batch_size.to_string())
|
||||
} else {
|
||||
argv.push("--max-batch-total-tokens".to_string());
|
||||
argv.push(max_batch_total_tokens.to_string())
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(ref revision) = revision {
|
||||
argv.push("--revision".to_string());
|
||||
|
|
|
@ -9,6 +9,8 @@ service TextGenerationService {
|
|||
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
||||
/// Empties batch cache
|
||||
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
||||
/// Remove requests from a cached batch
|
||||
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
|
||||
/// Prefill batch and decode first token
|
||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
|
@ -89,6 +91,8 @@ message Batch {
|
|||
repeated Request requests = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
}
|
||||
|
||||
enum FinishReason {
|
||||
|
@ -134,6 +138,19 @@ message Generation {
|
|||
GeneratedText generated_text = 7;
|
||||
}
|
||||
|
||||
message FilterBatchRequest {
|
||||
/// Batch ID
|
||||
uint64 batch_id = 1;
|
||||
/// Requests to keep
|
||||
repeated Request keep_requests = 2;
|
||||
}
|
||||
|
||||
message FilterBatchResponse {
|
||||
/// Filtered Batch (cached)
|
||||
Batch batch = 1;
|
||||
}
|
||||
|
||||
|
||||
message PrefillRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
|
|
|
@ -70,6 +70,22 @@ impl Client {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
keep_requests: Vec<Request>,
|
||||
) -> Result<Option<Batch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
keep_requests,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/// Multi shard Client
|
||||
use crate::Result;
|
||||
use crate::{Batch, Client, Generation, ShardInfo};
|
||||
use crate::{Batch, Client, Generation, Request, ShardInfo};
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
@ -59,6 +59,22 @@ impl ShardedClient {
|
|||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
keep_requests: Vec<Request>,
|
||||
) -> Result<Option<Batch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
|
|
|
@ -39,12 +39,14 @@ impl Infer {
|
|||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
validation: Validation,
|
||||
max_batch_size: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_concurrent_requests: usize,
|
||||
requires_padding: bool,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new();
|
||||
let queue = Queue::new(requires_padding);
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
|
@ -52,7 +54,8 @@ impl Infer {
|
|||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client,
|
||||
max_batch_size,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
queue.clone(),
|
||||
shared.clone(),
|
||||
|
@ -232,18 +235,12 @@ impl Infer {
|
|||
/// Batches requests and sends them to the inference server
|
||||
async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
max_batch_size: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
queue: Queue,
|
||||
shared: Arc<Shared>,
|
||||
) {
|
||||
// Minimum batch size after which we try to add more requests
|
||||
let limit_min_batch_size = if max_batch_size > 1 {
|
||||
(max_batch_size / 2) as u32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
|
@ -252,7 +249,9 @@ async fn batching_task(
|
|||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
||||
while let Some((mut entries, batch, span)) =
|
||||
queue.next_batch(None, max_batch_total_tokens).await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
|
@ -263,24 +262,33 @@ async fn batching_task(
|
|||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
||||
|
||||
// If the current batch is too small, we try to add more requests to it
|
||||
if batch_size <= limit_min_batch_size {
|
||||
let min_size = match waiting_tokens {
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
_ if waiting_tokens >= max_waiting_tokens => None,
|
||||
// Minimum size criteria
|
||||
_ => Some(limit_min_batch_size as usize),
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens - batch_max_tokens;
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_batch_size - batch_size as usize)
|
||||
.await
|
||||
if let Some((mut new_entries, new_batch, span)) =
|
||||
queue.next_batch(min_size, token_budget).await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
||||
} else {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
|
@ -304,7 +312,7 @@ async fn batching_task(
|
|||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
|
@ -325,6 +333,7 @@ async fn batching_task(
|
|||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -341,22 +350,11 @@ async fn prefill(
|
|||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = match next_batch {
|
||||
None => None,
|
||||
Some(batch) => {
|
||||
let id = batch.id;
|
||||
let next_batch = filter_batch(batch, entries);
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
if next_batch.is_none() {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
next_batch
|
||||
}
|
||||
};
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
|
@ -384,22 +382,11 @@ async fn decode(
|
|||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = match next_batch {
|
||||
None => None,
|
||||
Some(batch) => {
|
||||
let id = batch.id;
|
||||
let next_batch = filter_batch(batch, entries);
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
if next_batch.is_none() {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
next_batch
|
||||
}
|
||||
};
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
|
@ -419,14 +406,35 @@ async fn decode(
|
|||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
|
||||
batch.requests.retain(|r| entries.contains_key(&r.id));
|
||||
let size = batch.requests.len();
|
||||
if size == 0 {
|
||||
return None;
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<Batch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.requests.retain(|r| entries.contains_key(&r.id));
|
||||
|
||||
if batch.requests.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.requests).await.unwrap()
|
||||
}
|
||||
batch.size = size as u32;
|
||||
Some(batch)
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
|
|
|
@ -31,8 +31,12 @@ struct Args {
|
|||
max_input_length: usize,
|
||||
#[clap(default_value = "1512", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "32000", long, env)]
|
||||
max_batch_total_tokens: u32,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
|
@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
|
|||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
waiting_served_ratio,
|
||||
mut max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
|
@ -119,6 +125,12 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.block_on(async {
|
||||
init_logging(otlp_endpoint, json_output);
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size{
|
||||
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
|
||||
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
|
||||
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
|
||||
}
|
||||
|
||||
if tokenizer.is_none() {
|
||||
tracing::warn!(
|
||||
"Could not find a fast tokenizer implementation for {tokenizer_name}"
|
||||
|
@ -174,7 +186,8 @@ fn main() -> Result<(), std::io::Error> {
|
|||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
|
|
|
@ -2,7 +2,6 @@ use crate::infer::InferError;
|
|||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::oneshot;
|
||||
|
@ -34,12 +33,12 @@ pub(crate) struct Queue {
|
|||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new() -> Self {
|
||||
pub(crate) fn new(requires_padding: bool) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(queue_receiver));
|
||||
tokio::spawn(queue_task(requires_padding, queue_receiver));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
|
@ -59,7 +58,7 @@ impl Queue {
|
|||
pub(crate) async fn next_batch(
|
||||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
// Create response channel
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
|
@ -68,7 +67,7 @@ impl Queue {
|
|||
self.queue_sender
|
||||
.send(QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span: Span::current(),
|
||||
})
|
||||
|
@ -80,20 +79,24 @@ impl Queue {
|
|||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
|
||||
let mut state = State::new();
|
||||
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
|
||||
let mut state = State::new(requires_padding);
|
||||
|
||||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(entry));
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span,
|
||||
} => span.in_scope(|| {
|
||||
let next_batch = state.next_batch(min_size, max_size);
|
||||
let next_batch = state.next_batch(min_size, token_budget);
|
||||
response_sender.send(next_batch).unwrap_or(());
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
@ -110,14 +113,18 @@ struct State {
|
|||
|
||||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new() -> Self {
|
||||
fn new(requires_padding: bool) -> Self {
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,11 +137,10 @@ impl State {
|
|||
// Push entry in the queue
|
||||
self.entries.push_back((self.next_id, entry));
|
||||
self.next_id += 1;
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
|
||||
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
|
||||
if self.entries.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
@ -146,17 +152,19 @@ impl State {
|
|||
}
|
||||
}
|
||||
|
||||
let max_batch_size = min(self.entries.len(), max_size);
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(max_batch_size);
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
// Iterate on buffer
|
||||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
|
@ -165,6 +173,24 @@ impl State {
|
|||
continue;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||
} else {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
}
|
||||
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
|
||||
if (prefill_tokens + decode_tokens) > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
|
@ -184,21 +210,29 @@ impl State {
|
|||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
if batch_requests.len() == max_batch_size {
|
||||
// We have enough requests in the batch
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||
|
||||
// Maybe all entries were dropped because their channel were closed
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Final batch size once we dropped entries
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch_requests.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
|
||||
|
@ -206,11 +240,13 @@ impl State {
|
|||
id: self.next_batch_id,
|
||||
requests: batch_requests,
|
||||
size,
|
||||
max_tokens: (prefill_tokens + decode_tokens),
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
}
|
||||
|
@ -222,7 +258,7 @@ enum QueueCommand {
|
|||
Append(Entry, Span),
|
||||
NextBatch {
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
token_budget: u32,
|
||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||
span: Span,
|
||||
},
|
||||
|
@ -243,6 +279,7 @@ mod tests {
|
|||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: "".to_string(),
|
||||
input_length: 0,
|
||||
truncate: 0,
|
||||
parameters: NextTokenChooserParameters {
|
||||
temperature: 0.0,
|
||||
|
@ -256,7 +293,7 @@ mod tests {
|
|||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 0,
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
},
|
||||
|
@ -271,7 +308,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new();
|
||||
let mut state = State::new(false);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
|
@ -287,7 +324,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new();
|
||||
let mut state = State::new(false);
|
||||
|
||||
assert!(state.next_batch(None, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), 1).is_none());
|
||||
|
@ -295,7 +332,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new();
|
||||
let mut state = State::new(false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -326,8 +363,8 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = State::new();
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -360,14 +397,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new();
|
||||
let queue = Queue::new(false);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new();
|
||||
let queue = Queue::new(false);
|
||||
|
||||
assert!(queue.next_batch(None, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), 1).await.is_none());
|
||||
|
@ -375,7 +412,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new();
|
||||
let queue = Queue::new(false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -397,8 +434,8 @@ mod tests {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new();
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -423,7 +460,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new();
|
||||
let queue = Queue::new(false);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
|
|
@ -511,7 +511,8 @@ pub async fn run(
|
|||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
|
@ -571,9 +572,11 @@ pub async fn run(
|
|||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
max_batch_size,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_concurrent_requests,
|
||||
shard_info.requires_padding,
|
||||
);
|
||||
|
||||
// Duration buckets
|
||||
|
@ -604,7 +607,7 @@ pub async fn run(
|
|||
.collect();
|
||||
// Batch size buckets
|
||||
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
||||
let batch_size_buckets: Vec<f64> = (0..max_batch_size).map(|x| (x + 1) as f64).collect();
|
||||
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
||||
|
||||
// Prometheus handler
|
||||
let builder = PrometheusBuilder::new()
|
||||
|
|
|
@ -69,7 +69,7 @@ impl Validation {
|
|||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: u32,
|
||||
) -> Result<String, ValidationError> {
|
||||
) -> Result<(String, usize), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
|
@ -105,25 +105,24 @@ impl Validation {
|
|||
}
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
Ok(inputs)
|
||||
Ok((inputs, input_length))
|
||||
}
|
||||
// Return inputs without validation
|
||||
else {
|
||||
// In this case, we don't know the real length in tokens of the inputs
|
||||
// However, the inputs will be truncated by the python servers
|
||||
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
|
||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||
|
||||
// Validate MaxNewTokens
|
||||
if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
|
||||
> self.max_total_tokens as u32
|
||||
{
|
||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||
return Err(ValidationError::MaxNewTokens(
|
||||
self.max_total_tokens - self.max_input_length,
|
||||
max_new_tokens,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(inputs)
|
||||
Ok((inputs, input_length))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -238,7 +237,7 @@ impl Validation {
|
|||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// Validate inputs
|
||||
let inputs = self
|
||||
let (inputs, input_length) = self
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
|
@ -262,6 +261,7 @@ impl Validation {
|
|||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
input_length: input_length as u32,
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
parameters,
|
||||
stopping_parameters,
|
||||
|
@ -333,6 +333,7 @@ type TokenizerRequest = (
|
|||
#[derive(Debug)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: String,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub parameters: NextTokenChooserParameters,
|
||||
pub stopping_parameters: StoppingCriteriaParameters,
|
||||
|
|
|
@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
next_batch = next_batch.filter([next_batch.requests[0]])
|
||||
|
||||
for _ in range(
|
||||
stopping_criterias[0].max_new_tokens
|
||||
- stopping_criterias[1].max_new_tokens
|
||||
- 1
|
||||
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
|
||||
):
|
||||
generations, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
|
|
@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
# Copy stopping_criterias before filtering
|
||||
stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
|
||||
stopping_criterias = (
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
|
||||
)
|
||||
|
||||
next_batch = next_batch.filter([next_batch.requests[0]])
|
||||
|
||||
for _ in range(
|
||||
stopping_criterias[0].max_new_tokens
|
||||
- stopping_criterias[1].max_new_tokens
|
||||
- 1
|
||||
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
|
||||
):
|
||||
generations, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert len(generations) == len(next_batch)
|
||||
|
|
|
@ -46,6 +46,9 @@ class CausalLMBatch(Batch):
|
|||
max_input_length: int
|
||||
padding_right_offset: int
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
max_tokens: int
|
||||
|
||||
# Past metadata
|
||||
keys_head_dim_last: bool = True
|
||||
|
||||
|
@ -54,6 +57,7 @@ class CausalLMBatch(Batch):
|
|||
id=self.batch_id,
|
||||
requests=self.requests,
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -73,6 +77,7 @@ class CausalLMBatch(Batch):
|
|||
# Parse batch
|
||||
max_truncation = 0
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
requests_idx_mapping[r.id] = i
|
||||
inputs.append(r.inputs)
|
||||
|
@ -84,6 +89,7 @@ class CausalLMBatch(Batch):
|
|||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
@ -112,6 +118,8 @@ class CausalLMBatch(Batch):
|
|||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||
|
||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
|
@ -128,6 +136,7 @@ class CausalLMBatch(Batch):
|
|||
stopping_criterias=stopping_criterias,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
|
@ -150,6 +159,7 @@ class CausalLMBatch(Batch):
|
|||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
total_remaining_decode_tokens = 0
|
||||
new_padding_right_offset = 0
|
||||
|
||||
for i, r in enumerate(requests):
|
||||
|
@ -168,19 +178,23 @@ class CausalLMBatch(Batch):
|
|||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
new_padding_right_offset = max(
|
||||
new_padding_right_offset,
|
||||
remaining_decode_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
total_remaining_decode_tokens += remaining_decode_tokens
|
||||
new_padding_right_offset = max(
|
||||
new_padding_right_offset, remaining_decode_tokens
|
||||
)
|
||||
|
||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||
input_ids = self.input_ids[keep_indices]
|
||||
position_ids = self.position_ids[keep_indices]
|
||||
self.attention_mask = self.attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_input_length):
|
||||
(self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
|
||||
-(self.padding_right_offset + max_input_length) : (
|
||||
self.attention_mask.shape[1] - self.padding_right_offset
|
||||
)
|
||||
+ new_padding_right_offset,
|
||||
]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
|
@ -203,6 +217,8 @@ class CausalLMBatch(Batch):
|
|||
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
|
||||
del past_values
|
||||
|
||||
max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
|
||||
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = input_ids
|
||||
|
@ -215,6 +231,7 @@ class CausalLMBatch(Batch):
|
|||
self.stopping_criterias = stopping_criterias
|
||||
self.max_input_length = max_input_length
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
return self
|
||||
|
||||
|
@ -239,6 +256,7 @@ class CausalLMBatch(Batch):
|
|||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
max_tokens = 0
|
||||
|
||||
# Batch tensors
|
||||
input_ids = None
|
||||
|
@ -314,7 +332,8 @@ class CausalLMBatch(Batch):
|
|||
# And ensure that we can update tensors in-place
|
||||
if type(batch.past_key_values[0]) == tuple:
|
||||
batch.past_key_values = [
|
||||
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
|
||||
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
|
||||
for layer in batch.past_key_values
|
||||
]
|
||||
elif batch.past_key_values[0][0].shape == 3:
|
||||
for layer in batch.past_key_values:
|
||||
|
@ -322,6 +341,10 @@ class CausalLMBatch(Batch):
|
|||
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
|
||||
|
||||
start_index = end_index
|
||||
# Add eventual padding tokens that were added while concatenating
|
||||
max_tokens += batch.max_tokens + (
|
||||
max_input_length - batch.max_input_length
|
||||
) * len(batch)
|
||||
|
||||
first_past_kvs = batches[0].past_key_values
|
||||
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
|
||||
|
@ -371,7 +394,9 @@ class CausalLMBatch(Batch):
|
|||
|
||||
start_index = end_index
|
||||
|
||||
padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
|
||||
padded_past_values = first_past_kvs[j][1].new_zeros(
|
||||
padded_past_values_shape
|
||||
)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
past_values = batch.past_key_values[j][1]
|
||||
|
@ -387,6 +412,7 @@ class CausalLMBatch(Batch):
|
|||
] = past_values[:, :, -past_seq_len:, :]
|
||||
del past_values
|
||||
|
||||
# Update values
|
||||
start_index = end_index
|
||||
|
||||
past_key_values.append([padded_past_keys, padded_past_values])
|
||||
|
@ -408,6 +434,7 @@ class CausalLMBatch(Batch):
|
|||
max_input_length=max_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -56,9 +56,15 @@ class FlashCausalLMBatch(Batch):
|
|||
# Constant shared tensor, ref here just so that it's accessible in concatentate()
|
||||
past_pad: Optional[torch.Tensor]
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
max_tokens: int
|
||||
|
||||
def to_pb(self) -> generate_pb2.Batch:
|
||||
return generate_pb2.Batch(
|
||||
id=self.batch_id, requests=self.requests, size=len(self)
|
||||
id=self.batch_id,
|
||||
requests=self.requests,
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -86,6 +92,8 @@ class FlashCausalLMBatch(Batch):
|
|||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
|
||||
max_tokens = 0
|
||||
|
||||
# Parse batch
|
||||
for i, r in enumerate(pb.requests):
|
||||
# request id -> idx in list mapping
|
||||
|
@ -115,16 +123,20 @@ class FlashCausalLMBatch(Batch):
|
|||
cu_seqlens.append(cumulative_length + input_length)
|
||||
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
max_new_tokens = stopping_criteria.max_new_tokens
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
all_input_ids_tensor.append(
|
||||
F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
|
||||
)
|
||||
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
max_tokens += input_length + max_new_tokens
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
|
@ -143,6 +155,7 @@ class FlashCausalLMBatch(Batch):
|
|||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
past_pad=None,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
|
@ -177,6 +190,8 @@ class FlashCausalLMBatch(Batch):
|
|||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
max_tokens = 0
|
||||
|
||||
for i, r in enumerate(requests):
|
||||
idx = self.requests_idx_mapping[r.id]
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
@ -203,9 +218,14 @@ class FlashCausalLMBatch(Batch):
|
|||
token_offsets.append(self.token_offsets[idx])
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criterias.append(self.stopping_criterias[idx])
|
||||
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
cumulative_length += request_input_length
|
||||
max_tokens += request_input_length + (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
||||
if single_request:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
|
@ -241,6 +261,7 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -269,6 +290,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
cumulative_length = 0
|
||||
max_tokens = 0
|
||||
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
|
@ -310,6 +332,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Update
|
||||
cumulative_length += batch.cu_seqlens[-1]
|
||||
cumulative_batch_size += len(batch)
|
||||
max_tokens += batch.max_tokens
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=batches[0].batch_id,
|
||||
|
@ -328,6 +351,7 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -101,6 +101,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
# Parse batch
|
||||
max_truncation = 0
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
requests_idx_mapping[r.id] = i
|
||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||
|
@ -113,6 +114,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
@ -141,6 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||
|
||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
|
@ -157,6 +161,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
stopping_criterias=stopping_criterias,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -54,10 +54,16 @@ class Seq2SeqLMBatch(Batch):
|
|||
max_decoder_input_length: int
|
||||
padding_right_offset: int
|
||||
|
||||
# Maximum number of tokens this batch will grow to
|
||||
max_tokens: int
|
||||
|
||||
def to_pb(self) -> generate_pb2.Batch:
|
||||
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
|
||||
return generate_pb2.Batch(
|
||||
id=self.batch_id, requests=self.requests, size=len(self)
|
||||
id=self.batch_id,
|
||||
requests=self.requests,
|
||||
size=len(self),
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -80,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
# Parse batch
|
||||
max_truncation = 0
|
||||
padding_right_offset = 0
|
||||
max_decode_tokens = 0
|
||||
for i, r in enumerate(pb.requests):
|
||||
inputs.append(r.inputs)
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
@ -92,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
@ -117,6 +125,8 @@ class Seq2SeqLMBatch(Batch):
|
|||
)
|
||||
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
||||
|
||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
|
@ -137,6 +147,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
max_input_length=max_input_length.item(),
|
||||
max_decoder_input_length=1,
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
|
@ -166,6 +177,8 @@ class Seq2SeqLMBatch(Batch):
|
|||
max_decoder_input_length = 0
|
||||
padding_right_offset = 0
|
||||
|
||||
remaining_decode_tokens = 0
|
||||
|
||||
for i, r in enumerate(requests):
|
||||
idx = self.requests_idx_mapping[r.id]
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
@ -187,11 +200,16 @@ class Seq2SeqLMBatch(Batch):
|
|||
)
|
||||
padding_right_offset = max(
|
||||
padding_right_offset,
|
||||
self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
|
||||
self.stopping_criterias[idx].max_new_tokens
|
||||
- self.stopping_criterias[idx].current_tokens,
|
||||
)
|
||||
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criterias.append(self.stopping_criterias[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
remaining_decode_tokens += (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
||||
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
|
||||
|
@ -199,15 +217,21 @@ class Seq2SeqLMBatch(Batch):
|
|||
if self.decoder_attention_mask is not None:
|
||||
self.decoder_attention_mask = self.decoder_attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_decoder_input_length):
|
||||
(self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
|
||||
-(self.padding_right_offset + max_decoder_input_length) : (
|
||||
self.decoder_attention_mask.shape[1] - self.padding_right_offset
|
||||
)
|
||||
+ padding_right_offset,
|
||||
]
|
||||
|
||||
self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
|
||||
self.encoder_last_hidden_state = self.encoder_last_hidden_state[
|
||||
keep_indices, -max_input_length:
|
||||
]
|
||||
|
||||
# Ensure that past_key_values tensors can be updated in-place
|
||||
if type(self.past_key_values[0]) == tuple:
|
||||
self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
|
||||
self.past_key_values = [
|
||||
[t for t in layer] for layer in self.past_key_values
|
||||
]
|
||||
|
||||
decoder_past_seq_len = max_decoder_input_length - 1
|
||||
for layer in self.past_key_values:
|
||||
|
@ -216,6 +240,11 @@ class Seq2SeqLMBatch(Batch):
|
|||
layer[2] = layer[2][keep_indices, :, -max_input_length:]
|
||||
layer[3] = layer[3][keep_indices, :, -max_input_length:]
|
||||
|
||||
max_tokens = (
|
||||
len(requests) * (max_input_length + max_decoder_input_length)
|
||||
+ remaining_decode_tokens
|
||||
)
|
||||
|
||||
self.requests = requests
|
||||
self.requests_idx_mapping = requests_idx_mapping
|
||||
self.input_ids = None
|
||||
|
@ -229,10 +258,10 @@ class Seq2SeqLMBatch(Batch):
|
|||
self.max_input_length = max_input_length
|
||||
self.max_decoder_input_length = max_decoder_input_length
|
||||
self.padding_right_offset = padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
|
||||
|
@ -261,6 +290,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
token_offsets = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
max_tokens = 0
|
||||
|
||||
# Batch tensors
|
||||
attention_mask = None
|
||||
|
@ -363,9 +393,18 @@ class Seq2SeqLMBatch(Batch):
|
|||
|
||||
# Ensure that we can update tensors in-place
|
||||
if type(batch.past_key_values[0]) == tuple:
|
||||
batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
|
||||
batch.past_key_values = [
|
||||
[t for t in layer] for layer in batch.past_key_values
|
||||
]
|
||||
|
||||
start_index = end_index
|
||||
# Add eventual padding tokens that were added while concatenating
|
||||
max_tokens += batch.max_tokens + (
|
||||
max_input_length
|
||||
- batch.max_input_length
|
||||
+ max_decoder_input_length
|
||||
- batch.max_decoder_input_length
|
||||
) * len(batch)
|
||||
|
||||
# Determine shapes for new past kv tensors
|
||||
first_past_kvs = batches[0].past_key_values
|
||||
|
@ -404,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
|
|||
end_index = start_index + len(batch)
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
past_seq_len = batch.max_decoder_input_length - 1
|
||||
padded_past_values[
|
||||
start_index:end_index, :, -past_seq_len:, :
|
||||
] = t[:, :, -past_seq_len:, :]
|
||||
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
|
||||
:, :, -past_seq_len:, :
|
||||
]
|
||||
del t
|
||||
|
||||
start_index = end_index
|
||||
|
@ -426,8 +465,8 @@ class Seq2SeqLMBatch(Batch):
|
|||
end_index = start_index + len(batch)
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
padded_past_values[
|
||||
start_index:end_index, :, -batch.max_input_length:, :
|
||||
] = t[:, :, -batch.max_input_length:, :]
|
||||
start_index:end_index, :, -batch.max_input_length :, :
|
||||
] = t[:, :, -batch.max_input_length :, :]
|
||||
del t
|
||||
|
||||
start_index = end_index
|
||||
|
@ -452,6 +491,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
max_input_length=max_input_length,
|
||||
max_decoder_input_length=max_decoder_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -41,6 +41,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
torch.cuda.empty_cache()
|
||||
return generate_pb2.ClearCacheResponse()
|
||||
|
||||
async def FilterBatch(self, request, context):
|
||||
batch = self.cache.pop(request.batch_id)
|
||||
if batch is None:
|
||||
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
|
||||
filtered_batch = batch.filter(request.keep_requests)
|
||||
self.cache.set(filtered_batch)
|
||||
|
||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.device
|
||||
|
@ -63,8 +72,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
batch = self.cache.pop(batch_pb.id)
|
||||
if batch is None:
|
||||
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
|
||||
batch = batch.filter(batch_pb.requests)
|
||||
if batch is not None:
|
||||
batches.append(batch)
|
||||
|
||||
if len(batches) == 0:
|
||||
|
|
Loading…
Reference in New Issue