feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
OlivierDehaene 2023-04-24 17:59:00 +02:00 committed by GitHub
parent 98a3e0d135
commit ebc74d5666
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 399 additions and 172 deletions

View File

@ -39,8 +39,12 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
@ -93,6 +97,8 @@ fn main() -> ExitCode {
max_input_length,
max_total_tokens,
max_batch_size,
max_batch_total_tokens,
waiting_served_ratio,
max_waiting_tokens,
port,
shard_uds_path,
@ -380,8 +386,8 @@ fn main() -> ExitCode {
max_input_length.to_string(),
"--max-total-tokens".to_string(),
max_total_tokens.to_string(),
"--max-batch-size".to_string(),
max_batch_size.to_string(),
"--waiting-served-ratio".to_string(),
waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
max_waiting_tokens.to_string(),
"--port".to_string(),
@ -392,6 +398,15 @@ fn main() -> ExitCode {
model_id,
];
// Deprecate max_batch_size
if let Some(max_batch_size) = max_batch_size {
argv.push("--max-batch-size".to_string());
argv.push(max_batch_size.to_string())
} else {
argv.push("--max-batch-total-tokens".to_string());
argv.push(max_batch_total_tokens.to_string())
}
// Model optional revision
if let Some(ref revision) = revision {
argv.push("--revision".to_string());

View File

@ -9,6 +9,8 @@ service TextGenerationService {
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
@ -89,6 +91,8 @@ message Batch {
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
enum FinishReason {
@ -134,6 +138,19 @@ message Generation {
GeneratedText generated_text = 7;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated Request keep_requests = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
Batch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;

View File

@ -70,6 +70,22 @@ impl Client {
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
keep_requests: Vec<Request>,
) -> Result<Option<Batch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
keep_requests,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -1,6 +1,6 @@
/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, Generation, ShardInfo};
use crate::{Batch, Client, Generation, Request, ShardInfo};
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
@ -59,6 +59,22 @@ impl ShardedClient {
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
keep_requests: Vec<Request>,
) -> Result<Option<Batch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch

View File

@ -39,12 +39,14 @@ impl Infer {
pub(crate) fn new(
client: ShardedClient,
validation: Validation,
max_batch_size: usize,
waiting_served_ratio: f32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
requires_padding: bool,
) -> Self {
// Infer shared state
let queue = Queue::new();
let queue = Queue::new(requires_padding);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
@ -52,7 +54,8 @@ impl Infer {
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
max_batch_size,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
queue.clone(),
shared.clone(),
@ -232,18 +235,12 @@ impl Infer {
/// Batches requests and sends them to the inference server
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
waiting_served_ratio: f32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
queue: Queue,
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = if max_batch_size > 1 {
(max_batch_size / 2) as u32
} else {
0
};
// Infinite loop
loop {
// Wait for a notification from the Infer struct
@ -252,7 +249,9 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
while let Some((mut entries, batch, span)) =
queue.next_batch(None, max_batch_total_tokens).await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
@ -263,24 +262,33 @@ async fn batching_task(
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
let min_size = match waiting_tokens {
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
_ if waiting_tokens >= max_waiting_tokens => None,
// Minimum size criteria
_ => Some(limit_min_batch_size as usize),
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens - batch_max_tokens;
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
if let Some((mut new_entries, new_batch, span)) =
queue.next_batch(min_size, token_budget).await
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
@ -304,7 +312,7 @@ async fn batching_task(
batches.push(new_cached_batch);
}
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
@ -325,6 +333,7 @@ async fn batching_task(
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
}
}
}
@ -341,22 +350,11 @@ async fn prefill(
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
None => None,
Some(batch) => {
let id = batch.id;
let next_batch = filter_batch(batch, entries);
// Next batch is now empty
// Clear it from the Python shards cache
if next_batch.is_none() {
let _ = client.clear_cache(Some(id)).await;
}
next_batch
}
};
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
@ -384,22 +382,11 @@ async fn decode(
match client.decode(batches).await {
Ok((generations, next_batch)) => {
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
None => None,
Some(batch) => {
let id = batch.id;
let next_batch = filter_batch(batch, entries);
// Next batch is now empty
// Clear it from the Python shards cache
if next_batch.is_none() {
let _ = client.clear_cache(Some(id)).await;
}
next_batch
}
};
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
@ -419,14 +406,35 @@ async fn decode(
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
batch.requests.retain(|r| entries.contains_key(&r.id));
let size = batch.requests.len();
if size == 0 {
return None;
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<Batch>,
entries: &IntMap<u64, Entry>,
) -> Option<Batch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.requests.retain(|r| entries.contains_key(&r.id));
if batch.requests.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.requests).await.unwrap()
}
batch.size = size as u32;
Some(batch)
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`

View File

@ -31,8 +31,12 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
max_input_length,
max_total_tokens,
max_batch_size,
waiting_served_ratio,
mut max_batch_total_tokens,
max_waiting_tokens,
port,
master_shard_uds_path,
@ -119,6 +125,12 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
init_logging(otlp_endpoint, json_output);
if let Some(max_batch_size) = max_batch_size{
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
}
if tokenizer.is_none() {
tracing::warn!(
"Could not find a fast tokenizer implementation for {tokenizer_name}"
@ -174,7 +186,8 @@ fn main() -> Result<(), std::io::Error> {
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
sharded_client,
tokenizer,

View File

@ -2,7 +2,6 @@ use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::{Batch, Request};
use tokio::sync::oneshot;
@ -34,12 +33,12 @@ pub(crate) struct Queue {
}
impl Queue {
pub(crate) fn new() -> Self {
pub(crate) fn new(requires_padding: bool) -> Self {
// Create channel
let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task
tokio::spawn(queue_task(queue_receiver));
tokio::spawn(queue_task(requires_padding, queue_receiver));
Self { queue_sender }
}
@ -59,7 +58,7 @@ impl Queue {
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
max_size: usize,
token_budget: u32,
) -> Option<NextBatch> {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
@ -68,7 +67,7 @@ impl Queue {
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
max_size,
token_budget,
response_sender,
span: Span::current(),
})
@ -80,20 +79,24 @@ impl Queue {
}
// Background task responsible of the queue state
async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
let mut state = State::new();
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
let mut state = State::new(requires_padding);
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
QueueCommand::NextBatch {
min_size,
max_size,
token_budget,
response_sender,
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size);
let next_batch = state.next_batch(min_size, token_budget);
response_sender.send(next_batch).unwrap_or(());
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
}
}
@ -110,14 +113,18 @@ struct State {
/// Id of the next batch
next_batch_id: u64,
/// Whether the model is using padding
requires_padding: bool,
}
impl State {
fn new() -> Self {
fn new(requires_padding: bool) -> Self {
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
requires_padding,
}
}
@ -130,11 +137,10 @@ impl State {
// Push entry in the queue
self.entries.push_back((self.next_id, entry));
self.next_id += 1;
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
// Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<NextBatch> {
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
if self.entries.is_empty() {
return None;
}
@ -146,17 +152,19 @@ impl State {
}
}
let max_batch_size = min(self.entries.len(), max_size);
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(max_batch_size);
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
// Iterate on buffer
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
// Pop entries starting from the front of the queue
while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
@ -165,6 +173,24 @@ impl State {
continue;
}
if self.requires_padding {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else {
prefill_tokens += entry.request.input_length;
}
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
if (prefill_tokens + decode_tokens) > token_budget {
// Entry is over budget
// Add it back to the front
self.entries.push_front((id, entry));
break;
}
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
@ -184,21 +210,29 @@ impl State {
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
if batch_requests.len() == max_batch_size {
// We have enough requests in the batch
break;
}
}
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
// Maybe all entries were dropped because their channel were closed
// Empty batch
if batch_requests.is_empty() {
return None;
}
// Final batch size once we dropped entries
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch_requests.len() < min_size {
// Add back entries to the queue in the correct order
for r in batch_requests.into_iter().rev() {
let id = r.id;
let entry = batch_entries.remove(&id).unwrap();
self.entries.push_front((id, entry));
}
return None;
}
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
@ -206,11 +240,13 @@ impl State {
id: self.next_batch_id,
requests: batch_requests,
size,
max_tokens: (prefill_tokens + decode_tokens),
};
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
}
@ -222,7 +258,7 @@ enum QueueCommand {
Append(Entry, Span),
NextBatch {
min_size: Option<usize>,
max_size: usize,
token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
},
@ -243,6 +279,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
input_length: 0,
truncate: 0,
parameters: NextTokenChooserParameters {
temperature: 0.0,
@ -256,7 +293,7 @@ mod tests {
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
max_new_tokens: 0,
max_new_tokens: 1,
stop_sequences: vec![],
},
},
@ -271,7 +308,7 @@ mod tests {
#[test]
fn test_append() {
let mut state = State::new();
let mut state = State::new(false);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@ -287,7 +324,7 @@ mod tests {
#[test]
fn test_next_batch_empty() {
let mut state = State::new();
let mut state = State::new(false);
assert!(state.next_batch(None, 1).is_none());
assert!(state.next_batch(Some(1), 1).is_none());
@ -295,7 +332,7 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new();
let mut state = State::new(false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -326,8 +363,8 @@ mod tests {
}
#[test]
fn test_next_batch_max_size() {
let mut state = State::new();
fn test_next_batch_token_budget() {
let mut state = State::new(false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -360,14 +397,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new();
let queue = Queue::new(false);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new();
let queue = Queue::new(false);
assert!(queue.next_batch(None, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1).await.is_none());
@ -375,7 +412,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new();
let queue = Queue::new(false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -397,8 +434,8 @@ mod tests {
}
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new();
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -423,7 +460,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new();
let queue = Queue::new(false);
let (entry, _) = default_entry();
queue.append(entry);

View File

@ -511,7 +511,8 @@ pub async fn run(
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize,
waiting_served_ratio: f32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
client: ShardedClient,
tokenizer: Option<Tokenizer>,
@ -571,9 +572,11 @@ pub async fn run(
let infer = Infer::new(
client,
validation,
max_batch_size,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_concurrent_requests,
shard_info.requires_padding,
);
// Duration buckets
@ -604,7 +607,7 @@ pub async fn run(
.collect();
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..max_batch_size).map(|x| (x + 1) as f64).collect();
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Prometheus handler
let builder = PrometheusBuilder::new()

View File

@ -69,7 +69,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: u32,
) -> Result<String, ValidationError> {
) -> Result<(String, usize), ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@ -105,25 +105,24 @@ impl Validation {
}
metrics::histogram!("tgi_request_input_length", input_length as f64);
Ok(inputs)
Ok((inputs, input_length))
}
// Return inputs without validation
else {
// In this case, we don't know the real length in tokens of the inputs
// However, the inputs will be truncated by the python servers
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
let input_length = truncate.unwrap_or(self.max_input_length);
// Validate MaxNewTokens
if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
> self.max_total_tokens as u32
{
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
return Err(ValidationError::MaxNewTokens(
self.max_total_tokens - self.max_input_length,
max_new_tokens,
));
}
Ok(inputs)
Ok((inputs, input_length))
}
}
@ -238,7 +237,7 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
let inputs = self
let (inputs, input_length) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
@ -262,6 +261,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
@ -333,6 +333,7 @@ type TokenizerRequest = (
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub truncate: u32,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,

View File

@ -181,9 +181,7 @@ def test_causal_lm_generate_token_completion_multi(
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
stopping_criterias[0].max_new_tokens
- stopping_criterias[1].max_new_tokens
- 1
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)

View File

@ -174,14 +174,14 @@ def test_causal_lm_generate_token_completion_multi(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_causal_lm_batch.stopping_criterias.copy()
stopping_criterias = (
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
stopping_criterias[0].max_new_tokens
- stopping_criterias[1].max_new_tokens
- 1
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)

View File

@ -46,6 +46,9 @@ class CausalLMBatch(Batch):
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
@ -54,6 +57,7 @@ class CausalLMBatch(Batch):
id=self.batch_id,
requests=self.requests,
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
@ -73,6 +77,7 @@ class CausalLMBatch(Batch):
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
@ -84,6 +89,7 @@ class CausalLMBatch(Batch):
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@ -112,6 +118,8 @@ class CausalLMBatch(Batch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
return cls(
batch_id=pb.id,
requests=pb.requests,
@ -128,6 +136,7 @@ class CausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
@ -150,6 +159,7 @@ class CausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, r in enumerate(requests):
@ -168,19 +178,23 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
new_padding_right_offset = max(
new_padding_right_offset,
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length):
(self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
# Ensure that past_key_values tensors can be updated in-place
@ -203,6 +217,8 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
@ -215,6 +231,7 @@ class CausalLMBatch(Batch):
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
@ -239,6 +256,7 @@ class CausalLMBatch(Batch):
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
input_ids = None
@ -314,7 +332,8 @@ class CausalLMBatch(Batch):
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif batch.past_key_values[0][0].shape == 3:
for layer in batch.past_key_values:
@ -322,6 +341,10 @@ class CausalLMBatch(Batch):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@ -371,7 +394,9 @@ class CausalLMBatch(Batch):
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(padded_past_values_shape)
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
@ -387,6 +412,7 @@ class CausalLMBatch(Batch):
] = past_values[:, :, -past_seq_len:, :]
del past_values
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
@ -408,6 +434,7 @@ class CausalLMBatch(Batch):
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
def __len__(self):

View File

@ -56,9 +56,15 @@ class FlashCausalLMBatch(Batch):
# Constant shared tensor, ref here just so that it's accessible in concatentate()
past_pad: Optional[torch.Tensor]
# Maximum number of tokens this batch will grow to
max_tokens: int
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
id=self.batch_id, requests=self.requests, size=len(self)
id=self.batch_id,
requests=self.requests,
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
@ -86,6 +92,8 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_length = 0
max_tokens = 0
# Parse batch
for i, r in enumerate(pb.requests):
# request id -> idx in list mapping
@ -115,16 +123,20 @@ class FlashCausalLMBatch(Batch):
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
all_input_ids_tensor.append(
F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
)
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
return cls(
batch_id=pb.id,
@ -143,6 +155,7 @@ class FlashCausalLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
past_pad=None,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
@ -177,6 +190,8 @@ class FlashCausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
@ -203,9 +218,14 @@ class FlashCausalLMBatch(Batch):
token_offsets.append(self.token_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
cumulative_length += request_input_length
max_tokens += request_input_length + (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
if single_request:
# Preallocate tensor for bs = 1 case
@ -241,6 +261,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
@classmethod
@ -269,6 +290,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -310,6 +332,7 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
@ -328,6 +351,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
def __len__(self):

View File

@ -101,6 +101,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
@ -113,6 +114,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@ -141,6 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
return cls(
batch_id=pb.id,
requests=pb.requests,
@ -157,6 +161,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)

View File

@ -54,10 +54,16 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id, requests=self.requests, size=len(self)
id=self.batch_id,
requests=self.requests,
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
@ -80,6 +86,7 @@ class Seq2SeqLMBatch(Batch):
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
@ -92,6 +99,7 @@ class Seq2SeqLMBatch(Batch):
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@ -117,6 +125,8 @@ class Seq2SeqLMBatch(Batch):
)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
return cls(
batch_id=pb.id,
requests=pb.requests,
@ -137,6 +147,7 @@ class Seq2SeqLMBatch(Batch):
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
@ -166,6 +177,8 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length = 0
padding_right_offset = 0
remaining_decode_tokens = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
@ -187,11 +200,16 @@ class Seq2SeqLMBatch(Batch):
)
padding_right_offset = max(
padding_right_offset,
self.stopping_criterias[idx].max_new_tokens - self.stopping_criterias[idx].current_tokens
self.stopping_criterias[idx].max_new_tokens
- self.stopping_criterias[idx].current_tokens,
)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens += (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
@ -199,15 +217,21 @@ class Seq2SeqLMBatch(Batch):
if self.decoder_attention_mask is not None:
self.decoder_attention_mask = self.decoder_attention_mask[
keep_indices,
-(self.padding_right_offset + max_decoder_input_length):
(self.decoder_attention_mask.shape[1] - self.padding_right_offset) + padding_right_offset,
-(self.padding_right_offset + max_decoder_input_length) : (
self.decoder_attention_mask.shape[1] - self.padding_right_offset
)
+ padding_right_offset,
]
self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
self.encoder_last_hidden_state = self.encoder_last_hidden_state[
keep_indices, -max_input_length:
]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
self.past_key_values = [
[t for t in layer] for layer in self.past_key_values
]
decoder_past_seq_len = max_decoder_input_length - 1
for layer in self.past_key_values:
@ -216,6 +240,11 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
max_tokens = (
len(requests) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
)
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = None
@ -229,10 +258,10 @@ class Seq2SeqLMBatch(Batch):
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@ -261,6 +290,7 @@ class Seq2SeqLMBatch(Batch):
token_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
attention_mask = None
@ -363,9 +393,18 @@ class Seq2SeqLMBatch(Batch):
# Ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
batch.past_key_values = [
[t for t in layer] for layer in batch.past_key_values
]
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length
- batch.max_input_length
+ max_decoder_input_length
- batch.max_decoder_input_length
) * len(batch)
# Determine shapes for new past kv tensors
first_past_kvs = batches[0].past_key_values
@ -404,9 +443,9 @@ class Seq2SeqLMBatch(Batch):
end_index = start_index + len(batch)
# We slice the past keys and values to remove the padding from previous batches
past_seq_len = batch.max_decoder_input_length - 1
padded_past_values[
start_index:end_index, :, -past_seq_len:, :
] = t[:, :, -past_seq_len:, :]
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
:, :, -past_seq_len:, :
]
del t
start_index = end_index
@ -426,8 +465,8 @@ class Seq2SeqLMBatch(Batch):
end_index = start_index + len(batch)
# We slice the past keys and values to remove the padding from previous batches
padded_past_values[
start_index:end_index, :, -batch.max_input_length:, :
] = t[:, :, -batch.max_input_length:, :]
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
del t
start_index = end_index
@ -452,6 +491,7 @@ class Seq2SeqLMBatch(Batch):
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
def __len__(self):

View File

@ -41,6 +41,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
torch.cuda.empty_cache()
return generate_pb2.ClearCacheResponse()
async def FilterBatch(self, request, context):
batch = self.cache.pop(request.batch_id)
if batch is None:
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
filtered_batch = batch.filter(request.keep_requests)
self.cache.set(filtered_batch)
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device
@ -63,8 +72,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.cache.pop(batch_pb.id)
if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batch = batch.filter(batch_pb.requests)
if batch is not None:
batches.append(batch)
if len(batches) == 0: