feat(router): drop requests when client closes the channel (#202)

OlivierDehaene 2023-04-20 11:07:40 +02:00 committed by GitHub
parent b6ee0ec7b0
commit 709d8936f6
15 changed files with 821 additions and 571 deletions
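
The router now notices when a client has gone away: before an entry is batched, the queue skips it if its response channel is disconnected, and while generating, a failed send on that channel is treated as "stop generating and drop the request". A minimal sketch of that pattern, assuming hypothetical helpers `still_wanted` and `push_token` (the real `Entry` and `response_tx` appear in the router diffs below):

use flume::{SendError, Sender};

// Illustrative entry: one queued request plus the channel back to the client task.
struct Entry<T> {
    response_tx: Sender<T>,
}

// Skip an entry before batching it if its client already hung up:
// once every receiver is dropped, the channel reports itself disconnected.
fn still_wanted<T>(entry: &Entry<T>) -> bool {
    !entry.response_tx.is_disconnected()
}

// Sending to a disconnected channel fails; the batching side treats that
// error as "the client dropped the request" and stops generating for it.
fn push_token<T>(entry: &Entry<T>, msg: T) -> Result<(), SendError<T>> {
    entry.response_tx.send(msg)
}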

.github/workflows/build.yaml

@ -66,7 +66,7 @@ jobs:
password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Login to Azure Container Registry
# if: github.event_name != 'pull_request'
if: github.event_name != 'pull_request'
uses: docker/login-action@v2.1.0
with:
username: ${{ secrets.AZURE_DOCKER_USERNAME }}

Cargo.lock (generated)

@ -42,42 +42,51 @@ dependencies = [
[[package]]
name = "anstream"
version = "0.2.6"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f"
checksum = "9e579a7752471abc2a8268df8b20005e3eadd975f585398f17efcfd8d4927371"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"concolor-override",
"concolor-query",
"colorchoice",
"is-terminal",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "0.3.5"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2"
checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d"
[[package]]
name = "anstyle-parse"
version = "0.1.1"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116"
checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-wincon"
version = "0.2.0"
name = "anstyle-query"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa"
checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b"
dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "anstyle-wincon"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bcd8291a340dd8ac70e18878bc4501dd7b4ff970cfa21c207d36ece51ea88fd"
dependencies = [
"anstyle",
"windows-sys 0.45.0",
"windows-sys 0.48.0",
]
[[package]]
@ -105,7 +114,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -116,7 +125,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -127,9 +136,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.13"
version = "0.6.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6539e4565c365448d483967c6dee3eaecb8e87679a17806a831e82b05b903c18"
checksum = "3b32c5ea3aabaf4deb5f5ced2d688ec0844c881c9e6c696a8b769a05fc691e62"
dependencies = [
"async-trait",
"axum-core",
@ -310,9 +319,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.2.1"
version = "4.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3"
checksum = "9b802d85aaf3a1cdb02b224ba472ebdea62014fccfcb269b95a4d76443b5ee5a"
dependencies = [
"clap_builder",
"clap_derive",
@ -321,9 +330,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.2.1"
version = "4.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f"
checksum = "14a1a858f532119338887a4b8e1af9c60de8249cd7bafd68036a489e261e37b6"
dependencies = [
"anstream",
"anstyle",
@ -341,7 +350,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -351,19 +360,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1"
[[package]]
name = "concolor-override"
name = "colorchoice"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f"
[[package]]
name = "concolor-query"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf"
dependencies = [
"windows-sys 0.45.0",
]
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "console"
@ -794,7 +794,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -868,9 +868,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.3.16"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
checksum = "17f8a914c2987b688368b5138aa05321db91f4090cf26118185672ad588bce21"
dependencies = [
"bytes",
"fnv",
@ -966,9 +966,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
[[package]]
name = "hyper"
version = "0.14.25"
version = "0.14.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899"
checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
dependencies = [
"bytes",
"futures-channel",
@ -1364,7 +1364,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -1517,7 +1517,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -1787,9 +1787,9 @@ dependencies = [
[[package]]
name = "prost"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive",
@ -1797,9 +1797,9 @@ dependencies = [
[[package]]
name = "prost-build"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12"
checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
dependencies = [
"bytes",
"heck",
@ -1819,9 +1819,9 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
"anyhow",
"itertools 0.10.5",
@ -1832,9 +1832,9 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
"prost",
]
@ -2153,14 +2153,14 @@ checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
name = "serde_json"
version = "1.0.95"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"itoa",
"ryu",
@ -2330,9 +2330,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.14"
version = "2.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5"
checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822"
dependencies = [
"proc-macro2",
"quote",
@ -2450,7 +2450,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -2578,7 +2578,7 @@ checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -2928,9 +2928,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "utoipa"
version = "3.2.1"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24e7ee17c9ef094b86e1e04170d90765bd76cb381921dacb4d3e175a267bdae6"
checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
dependencies = [
"indexmap",
"serde",
@ -2940,14 +2940,14 @@ dependencies = [
[[package]]
name = "utoipa-gen"
version = "3.2.1"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df6f458e5abc811d44aca28455efc4163fb7565a7af2aa32d17611f3d1d9794d"
checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]

router/src/infer.rs

@ -3,6 +3,7 @@ use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use flume::r#async::RecvStream;
use flume::SendError;
use futures::future::try_join_all;
use futures::stream::StreamExt;
use nohash_hasher::IntMap;
@ -11,7 +12,7 @@ use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{Notify, Semaphore, TryAcquireError};
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
@ -73,9 +74,14 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
) -> Result<
(
OwnedSemaphorePermit,
RecvStream<Result<InferStreamResponse, InferError>>,
),
InferError,
> {
// Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry
let permit = self
.clone()
.limit_concurrent_requests
@ -104,7 +110,6 @@ impl Infer {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
});
// Notify the background task that we have a new entry in the queue that needs
@ -112,7 +117,7 @@ impl Infer {
self.shared.batching_task.notify_one();
// Return stream
Ok(response_rx.into_stream())
Ok((permit, response_rx.into_stream()))
}
/// Add a new request to the queue and return an InferResponse
@ -121,8 +126,8 @@ impl Infer {
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// Create stream
let mut stream = self.generate_stream(request).await?;
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
@ -276,12 +281,10 @@ async fn batching_task(
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
let new_batch_size = new_batch.size;
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
@ -308,8 +311,7 @@ async fn batching_task(
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
@ -339,7 +341,23 @@ async fn prefill(
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
None => None,
Some(batch) => {
let id = batch.id;
let next_batch = filter_batch(batch, entries);
// Next batch is now empty
// Clear it from the Python shards cache
if next_batch.is_none() {
let _ = client.clear_cache(Some(id)).await;
}
next_batch
}
};
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
@ -361,17 +379,37 @@ async fn decode(
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
None => None,
Some(batch) => {
let id = batch.id;
let next_batch = filter_batch(batch, entries);
// Next batch is now empty
// Clear it from the Python shards cache
if next_batch.is_none() {
let _ = client.clear_cache(Some(id)).await;
}
next_batch
}
};
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None
@ -379,6 +417,86 @@ async fn decode(
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
batch.requests.retain(|r| entries.contains_key(&r.id));
let size = batch.requests.len();
if size == 0 {
return None;
}
batch.size = size as u32;
Some(batch)
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If we receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating, hence the unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
} else {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))?;
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
@ -397,65 +515,6 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
});
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument(skip_all)]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
.unwrap_or(());
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message
// We can `expect` here as the request id should always be in the entries
let entry = entries
.remove(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))
.unwrap_or(());
} else {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))
.unwrap_or(());
}
});
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message

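Both the sender-side error handling above and the queue filtering below rely on flume's disconnect semantics: once every `Receiver` of a channel has been dropped, `Sender::is_disconnected()` returns true and `Sender::send` returns `Err(SendError)`. A small self-contained check of that behaviour (values are arbitrary):

fn main() {
    let (tx, rx) = flume::unbounded::<u32>();

    // With a live receiver, the channel is connected and sends succeed.
    assert!(!tx.is_disconnected());
    assert!(tx.send(1).is_ok());

    // Dropping the last receiver is what a client disconnect looks like from
    // the sender side: the channel reports disconnected and sends now fail.
    drop(rx);
    assert!(tx.is_disconnected());
    assert!(tx.send(2).is_err());
}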
router/src/queue.rs

@ -3,8 +3,9 @@ use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::{Batch, Request};
use tokio::sync::{oneshot, OwnedSemaphorePermit};
use tokio::sync::oneshot;
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -23,8 +24,6 @@ pub(crate) struct Entry {
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Permit
pub _permit: OwnedSemaphorePermit,
}
/// Request Queue
@ -104,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
#[derive(Debug)]
struct State {
/// Queue entries organized in a Vec
entries: Vec<(u64, Entry)>,
entries: VecDeque<(u64, Entry)>,
/// Id of the next entry
next_id: u64,
@ -116,7 +115,7 @@ struct State {
impl State {
fn new() -> Self {
Self {
entries: Vec::with_capacity(128),
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
}
@ -129,7 +128,7 @@ impl State {
entry.temp_span = Some(queue_span);
// Push entry in the queue
self.entries.push((self.next_id, entry));
self.entries.push_back((self.next_id, entry));
self.next_id += 1;
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
@ -147,51 +146,70 @@ impl State {
}
}
let next_batch_size = min(self.entries.len(), max_size);
let max_batch_size = min(self.entries.len(), max_size);
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(next_batch_size);
let mut batch_requests = Vec::with_capacity(max_batch_size);
let mut batch_entries =
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
// Drain next_batch_size entries
self.entries
.drain(..next_batch_size)
.for_each(|(id, mut entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
// Iterate on buffer
while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_disconnected() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
continue;
}
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
if batch_requests.len() == max_batch_size {
// We have enough requests in the batch
break;
}
}
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
// Maybe all entries were dropped because their channels were closed
if batch_requests.is_empty() {
return None;
}
// Final batch size once we dropped entries
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
let batch = Batch {
id: self.next_batch_id,
requests: batch_requests,
size: next_batch_size as u32,
size,
};
// Increment batch id
self.next_batch_id += 1;
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
@ -213,17 +231,16 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::Semaphore;
use tracing::info_span;
fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
let (response_tx, _) = flume::unbounded();
let permit = semaphore.try_acquire_owned().unwrap();
fn default_entry() -> (
Entry,
flume::Receiver<Result<InferStreamResponse, InferError>>,
) {
let (response_tx, receiver_tx) = flume::unbounded();
Entry {
let entry = Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
truncate: 0,
@ -248,14 +265,14 @@ mod tests {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
}
};
(entry, receiver_tx)
}
#[test]
fn test_append() {
let mut state = State::new();
let entry = default_entry();
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
assert_eq!(state.entries.len(), 0);
@ -264,7 +281,7 @@ mod tests {
assert_eq!(state.next_id, 1);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 0);
}
@ -279,8 +296,10 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new();
state.append(default_entry());
state.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2);
@ -295,21 +314,24 @@ mod tests {
assert_eq!(state.entries.len(), 0);
assert_eq!(state.next_batch_id, 1);
state.append(default_entry());
let (entry3, _guard3) = default_entry();
state.append(entry3);
assert!(state.next_batch(Some(2), 2).is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
let (id, _) = state.entries.remove(0);
let (id, _) = state.entries.remove(0).unwrap();
assert_eq!(id, 2);
}
#[test]
fn test_next_batch_max_size() {
let mut state = State::new();
state.append(default_entry());
state.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1);
@ -321,7 +343,8 @@ mod tests {
assert_eq!(state.entries.len(), 1);
assert_eq!(state.next_batch_id, 1);
state.append(default_entry());
let (entry3, _guard3) = default_entry();
state.append(entry3);
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2);
@ -338,7 +361,8 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new();
queue.append(default_entry());
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
@ -352,8 +376,10 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new();
queue.append(default_entry());
queue.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2);
@ -364,7 +390,8 @@ mod tests {
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
queue.append(default_entry());
let (entry3, _guard3) = default_entry();
queue.append(entry3);
assert!(queue.next_batch(Some(2), 2).await.is_none());
}
@ -372,8 +399,10 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new();
queue.append(default_entry());
queue.append(default_entry());
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1);
@ -381,7 +410,8 @@ mod tests {
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 1);
queue.append(default_entry());
let (entry3, _guard3) = default_entry();
queue.append(entry3);
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2);
@ -390,4 +420,13 @@ mod tests {
assert_eq!(batch.id, 1);
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new();
let (entry, _) = default_entry();
queue.append(entry);
assert!(queue.next_batch(None, 1).await.is_none());
}
}

router/src/server.rs

@ -367,7 +367,8 @@ async fn generate_stream(
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
match response {

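With the handler change above, the HTTP task owns the semaphore permit for exactly as long as it streams events back. If the client closes the connection, the handler future is dropped: the permit is released and the stream's receiver goes with it, which is what the sender-side checks react to. A rough sketch of that shape, using a hypothetical item type and function name rather than the router's real signatures:

use futures::StreamExt;
use tokio::sync::OwnedSemaphorePermit;

// Hypothetical payload standing in for the router's stream items.
type Item = Result<String, String>;

async fn stream_to_client(
    permit: OwnedSemaphorePermit,
    mut responses: flume::r#async::RecvStream<'static, Item>,
) {
    // Holding the permit ties one concurrency slot to the lifetime of this task;
    // dropping the future (client gone) releases it automatically.
    let _permit = permit;
    while let Some(item) = responses.next().await {
        match item {
            Ok(_event) => { /* forward the event to the client here */ }
            Err(_err) => break, // report the error to the client, then stop
        }
    }
    // When this future completes or is dropped, the RecvStream's receiver is
    // dropped too, so the generation side sees the channel as disconnected.
}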
server/Makefile

@ -1,6 +1,9 @@
include Makefile-transformers
include Makefile-flash-att
unit-tests:
python -m pytest tests
gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir

server/tests/models/test_bloom.py

@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 1
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@ -269,6 +277,8 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
@ -290,6 +300,8 @@ def test_batch_concatenate(
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1]])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_bloom_batch.stopping_criterias[0].max_new_tokens

server/tests/models/test_causal_lm.py

@ -44,11 +44,12 @@ def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 1
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@ -67,12 +68,17 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert all(
[
torch.equal(input_ids, all_input_ids[:, 0])
for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
]
)
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
@ -93,7 +99,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids) == len(next_batch)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert next_batch.all_input_ids[0][-1] == 13
@ -103,7 +109,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids.shape == (len(next_batch), 1)
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
@ -168,6 +174,8 @@ def test_causal_lm_generate_token_completion_multi(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0]])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@ -266,6 +274,8 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
@ -285,6 +295,8 @@ def test_batch_concatenate(
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1]])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_causal_lm_batch.stopping_criterias[0].max_new_tokens

server/tests/models/test_seq2seq_lm.py

@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_0.id = 1
req_1 = default_pb_request
req_1.id = 1
req_1.id = 2
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
assert len(batch.decoder_input_ids) == default_pb_batch.size
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert len(batch) == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
)
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
assert next_batch.decoder_input_ids[0, 0] == 0
assert next_batch.decoder_input_ids[0, 1] == 259
assert len(next_batch.decoder_input_ids) == len(next_batch)
assert next_batch.all_decoder_input_ids[0][0] == 0
assert next_batch.all_decoder_input_ids[0][1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)
@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
[p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
[p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[
p[2].shape == (next_batch.size, 6, sequence_length, 64)
p[2].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
assert all(
[
p[3].shape == (next_batch.size, 6, sequence_length, 64)
p[3].shape == (len(next_batch), 6, sequence_length, 64)
for p in next_batch.past_key_values
]
)
@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi(
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0]])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
@ -223,7 +226,8 @@ def test_batch_concatenate(
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
assert next_batch.all_decoder_input_ids[1][0] == 0
assert next_batch.all_decoder_input_ids[2][0] == 0
assert torch.equal(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
@ -258,16 +262,16 @@ def test_batch_concatenate(
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
[p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
)
for i, past in enumerate(next_batch.past_key_values):
@ -306,6 +310,8 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
@ -314,6 +320,8 @@ def test_batch_concatenate(
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter([next_batch.requests[1]])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None

server/text_generation_server/models/causal_lm.py

@ -3,7 +3,7 @@ import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
class CausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
@ -42,7 +43,6 @@ class CausalLMBatch(Batch):
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
size: int
max_input_length: int
padding_right_offset: int
@ -53,7 +53,7 @@ class CausalLMBatch(Batch):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
size=len(self),
)
@classmethod
@ -68,11 +68,13 @@ class CausalLMBatch(Batch):
stopping_criterias = []
offsets = []
token_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
offsets.append(None)
token_offsets.append(None)
@ -108,26 +110,91 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
attention_mask = self.attention_mask[keep_indices]
position_ids = self.position_ids[keep_indices]
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_key_values = [
[t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
for layer in self.past_key_values
]
return CausalLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
padding_right_offset=self.padding_right_offset,
keys_head_dim_last=self.keys_head_dim_last,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
@ -136,12 +203,13 @@ class CausalLMBatch(Batch):
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
@ -167,8 +235,15 @@ class CausalLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + batch.size
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
@ -216,8 +291,8 @@ class CausalLMBatch(Batch):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
_, num_heads, padded_sequence_length, head_dim = past_values.shape
@ -265,11 +340,12 @@ class CausalLMBatch(Batch):
start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index += batch.size
start_index += len(batch)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -280,7 +356,6 @@ class CausalLMBatch(Batch):
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@ -364,22 +439,9 @@ class CausalLM(Model):
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_input_ids = []
next_batch_all_input_ids = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
@ -443,16 +505,7 @@ class CausalLM(Model):
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token_id)
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_max_input_length = max(
next_batch_max_input_length, new_input_length
)
stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
@ -484,62 +537,30 @@ class CausalLM(Model):
generations.append(generation)
# Update values
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
if stopped:
return generations, None
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values = [
[
t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
for t in layer
]
for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_position_ids = batch.position_ids
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
next_batch_attention_mask[:, -batch.padding_right_offset] = 1
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
batch.position_ids = batch.position_ids[:, -1:] + 1
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask,
position_ids=next_batch_position_ids,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last,
)
return generations, next_batch
# Update past key values
batch.past_key_values = past
return generations, batch

server/text_generation_server/models/flash_causal_lm.py

@ -6,7 +6,7 @@ from torch.nn import functional as F
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
@ -29,14 +29,16 @@ tracer = trace.get_tracer(__name__)
class FlashCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
input_ids: List[torch.Tensor]
position_ids: List[torch.Tensor]
# cumulative sequence lengths
cu_seqlens: torch.Tensor
cu_seqlens: List[int]
max_seqlen: int
past_key_values: Optional[torch.Tensor]
past_key_values: Optional[List[torch.Tensor]]
# All tokens
all_input_ids: List[List[int]]
@ -62,7 +64,7 @@ class FlashCausalLMBatch(Batch):
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
) -> "FlashCausalLMBatch":
input_ids = []
position_ids = []
cu_seqlens = [0]
@ -73,6 +75,7 @@ class FlashCausalLMBatch(Batch):
token_offsets = []
all_input_ids = []
all_input_ids_tensor = []
requests_idx_mapping = {}
next_token_choosers = []
stopping_criterias = []
@ -81,13 +84,18 @@ class FlashCausalLMBatch(Batch):
cumulative_length = 0
# Parse batch
for r in pb.requests:
for i, r in enumerate(pb.requests):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenizer(
r.inputs, truncation=True, max_length=r.truncate
)["input_ids"]
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
offsets.append(None)
token_offsets.append(None)
all_input_ids.append(tokenized_input)
@ -96,7 +104,9 @@ class FlashCausalLMBatch(Batch):
input_ids.append(tokenized_input)
# Position ids
position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
position_ids.append(
torch.arange(0, input_length, dtype=torch.int32, device=device)
)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
@ -113,13 +123,10 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
@ -134,60 +141,141 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(requests) == len(self):
return self
# Cumulative length
cumulative_length = 0
# New values after filtering
requests_idx_mapping = {}
input_ids = []
position_ids = []
cu_seqlens = [0]
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
next_token_choosers = []
stopping_criterias = []
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
# Get length
request_input_length = self.input_lengths[idx]
input_ids.append(self.input_ids[idx])
position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
past_key_values.append(self.past_key_values[idx])
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
cumulative_length += request_input_length
return FlashCausalLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
all_input_ids_tensor = []
next_token_choosers = []
stopping_criterias = []
requests_idx_mapping = {}
# Batch tensors
input_ids = []
position_ids = []
cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
cu_seqlens = [0]
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
next_token_choosers = []
stopping_criterias = []
# Cumulative length
cumulative_length = torch.tensor(0)
cumulative_batch_size = 0
cumulative_length = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
input_ids.extend(batch.input_ids)
position_ids.extend(batch.position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
past_key_values.extend(batch.past_key_values)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
token_offsets.extend(batch.token_offsets)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
input_ids.append(batch.input_ids)
position_ids.append(batch.position_ids)
past_key_values.append(batch.past_key_values)
max_seqlen = max(max_seqlen, batch.max_seqlen)
# Update
cumulative_length += batch.cu_seqlens[-1]
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
# Concat on dim=1 as first dim represents the model layers
past_key_values = torch.concat(past_key_values, dim=1)
cu_seqlens = torch.concat(cu_seqlens)
cumulative_batch_size += len(batch)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
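
When batches are merged, each per-batch request-id to row-index mapping has to be shifted by the rows already accumulated, which is what the cumulative_batch_size offset in the loop above does. A small illustrative sketch with made-up request ids:

# Hypothetical mappings: request id -> row index inside its own batch
batch_mappings = [{11: 0, 12: 1}, {27: 0}]
requests_idx_mapping = {}
cumulative_batch_size = 0
for mapping in batch_mappings:
    for request_id, index in mapping.items():
        requests_idx_mapping[request_id] = index + cumulative_batch_size
    cumulative_batch_size += len(mapping)
assert requests_idx_mapping == {11: 0, 12: 1, 27: 2}
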
@@ -269,38 +357,49 @@ class FlashCausalLM(Model):
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
# Better to send to device here to avoid device issues in concatenate
position_ids = batch.position_ids.to(self.device, non_blocking=True)
cu_seqlens = batch.cu_seqlens.to(self.device)
# Shortcut when batch_size == 1
if len(batch) == 1:
input_ids = batch.input_ids[0].view(-1)
past_key_values = (
batch.past_key_values[0] if batch.past_key_values is not None else None
)
else:
# Concatenate tensors
input_ids = torch.cat(batch.input_ids).view(-1)
past_key_values = (
torch.cat(batch.past_key_values, dim=1)
if batch.past_key_values is not None
else None
)
# Concatenate when prefill, torch.tensor when decode
position_ids = (
torch.tensor(batch.position_ids, device=self.device)
if batch.past_key_values is not None
else torch.cat(batch.position_ids)
)
cu_seqlens = torch.tensor(
batch.cu_seqlens, device=self.device, dtype=torch.int32
)
out, present = self.forward(
batch.input_ids,
input_ids,
position_ids,
cu_seqlens,
batch.max_seqlen,
batch.past_key_values,
past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_ids = []
next_batch_position_ids = []
next_batch_cu_seqlens = [0]
next_batch_max_seqlen = 0
next_batch_past_key_values = []
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_all_input_ids = []
next_batch_all_input_ids_tensor = []
# Initialize past_key_values in prefill
if batch.past_key_values is None:
batch.past_key_values = [None] * len(batch)
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
@@ -329,7 +428,8 @@ class FlashCausalLM(Model):
start_index = cumulative_length
end_index = cumulative_length + input_length
if batch.past_key_values is None:
prefill = stopping_criteria.current_tokens == 0
if prefill:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
logits = out[start_index:end_index]
@@ -348,7 +448,6 @@ class FlashCausalLM(Model):
# Append next token to all tokens
all_input_ids.append(next_token_id_item)
all_input_ids_tensor[input_length] = next_token_id_item
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id_item]
@@ -378,32 +477,23 @@ class FlashCausalLM(Model):
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
# CAUTION: generation will be stopped so no need to pad
# This will make the next forward crash if the request does not get filtered
new_input_length = input_length
past = present[:, start_index:end_index]
else:
# Keep request in the batch
next_batch_keep_indices.append(i)
stopped = False
generated_text = None
# Get sequence present
seq_present = present[:, start_index:end_index]
# Pad it for next iter attention
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
next_batch_past_key_values.append(past)
next_batch_input_ids.append(next_token_id)
next_batch_position_ids.append(input_length)
# Cumulative sum
next_batch_cu_seqlens.append(
next_batch_cu_seqlens[-1] + new_input_length
# Pad present for next iter attention
new_input_length = input_length + 1
past = torch.nn.functional.pad(
present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
)
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_all_input_ids.append(all_input_ids)
next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
# Prefill
if stopping_criteria.current_tokens == 1:
if prefill:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids_tensor[1:input_length].unsqueeze(1)
@@ -433,52 +523,18 @@ class FlashCausalLM(Model):
generations.append(generation)
cumulative_length += input_length
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generations, None
# Update values
batch.input_ids[i] = next_token_id
batch.position_ids[i] = input_length
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.all_input_ids[i] = all_input_ids
batch.all_input_ids_tensor[i] = all_input_ids_tensor
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
batch.past_key_values[i] = past
# Cumulative sum
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to requests, token_choosers and stopping_criterias that need to be cached
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Create final next batch tensors
next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32
)
next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
else:
next_batch_input_ids = next_batch_input_ids[0].view(1)
next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashCausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
position_ids=next_batch_position_ids,
cu_seqlens=next_batch_cu_seqlens,
max_seqlen=next_batch_max_seqlen,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
all_input_ids=next_batch_all_input_ids,
all_input_ids_tensor=next_batch_all_input_ids_tensor,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
)
return generations, next_batch
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
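
The torch.nn.functional.pad call applied to present above reserves one extra slot along the sequence dimension so the next decode step can write its key/value in place. A hedged sketch, assuming a 5-D [num_layers, seq_len, 2, num_heads, head_size] layout for present (the pad pairs apply to the last four dimensions, so the final (0, 1) pair lands on seq_len):

import torch
import torch.nn.functional as F

# Assumed layout: [num_layers, seq_len, 2 (key/value), num_heads, head_size]
present = torch.zeros(2, 5, 2, 4, 8)
# The first three (0, 0) pairs leave head_size, num_heads and the k/v dim alone;
# the last (0, 1) pair appends one slot to the sequence dimension.
padded = F.pad(present, (0, 0, 0, 0, 0, 0, 0, 1))
assert padded.shape == (2, 6, 2, 4, 8)
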

View File

@@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
stopping_criterias = []
offsets = []
token_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
offsets.append(None)
@@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset, stopping_criteria.max_new_tokens
)
# Tokenize batch
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
@@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_input_length=max_input_length,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
)
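
The new all_input_ids construction above relies on .T.split(1, dim=1) to turn the [batch_size, seq_len] token matrix into one column tensor per request. A small sketch with hypothetical token ids:

import torch

# Two requests, three tokens each (hypothetical ids)
input_ids = torch.tensor([[10, 11, 12], [20, 21, 22]])
# Transpose to [seq_len, batch_size], then split into one [seq_len, 1] column per request
per_request = input_ids.T.split(1, dim=1)
assert len(per_request) == 2
assert per_request[0].squeeze(1).tolist() == [10, 11, 12]
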

View File

@@ -3,7 +3,7 @@ import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
class Seq2SeqLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Encoder values
input_ids: torch.Tensor
@@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch):
decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor]
# All tokens
all_decoder_input_ids: List[torch.Tensor]
# Seq2SeqLM keeps track of both encoder and decoder attention keys and values
past_key_values: Optional[List[Tuple]]
@@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
size: int
max_input_length: int
max_decoder_input_length: int
padding_right_offset: int
@@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch):
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
id=self.batch_id, requests=self.requests, size=len(self)
)
@classmethod
@@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
decoder_input_ids = []
decoder_input_lengths = []
offsets = []
token_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
for i, r in enumerate(pb.requests):
inputs.append(r.inputs)
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
offsets.append(None)
token_offsets.append(None)
@@ -109,15 +109,22 @@ class Seq2SeqLMBatch(Batch):
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
# Decoder sequence only contains the bos_token
decoder_input_ids = (
torch.tensor(tokenizer.bos_token_id, device=device)
.repeat(len(pb.requests))
.view(-1, 1)
)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=list(all_decoder_input_ids),
decoder_attention_mask=None,
encoder_last_hidden_state=None,
past_key_values=None,
@@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch):
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
)
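
The decoder initialization above builds one bos token per request and keeps a per-request view for later appends. A minimal sketch, with a placeholder bos_token_id:

import torch

bos_token_id = 0  # placeholder value
batch_size = 3
# One bos token per request, shaped [batch_size, 1] for the first decoder forward
decoder_input_ids = torch.tensor(bos_token_id).repeat(batch_size).view(-1, 1)
assert decoder_input_ids.shape == (3, 1)
# Per-request tensors used to accumulate generated tokens independently
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
assert [t.tolist() for t in all_decoder_input_ids] == [[0], [0], [0]]
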
@tracer.start_as_current_span("filter")
def filter(
self, requests: List[generate_pb2.Request]
) -> Optional["Seq2SeqLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
decoder_input_lengths = []
offsets = []
token_offsets = []
all_decoder_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_input_length = 0
max_decoder_input_length = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
request_decoder_input_length = self.decoder_input_lengths[idx]
decoder_input_lengths.append(request_decoder_input_length)
max_decoder_input_length = max(
max_decoder_input_length, request_decoder_input_length
)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
decoder_input_ids = self.decoder_input_ids[keep_indices]
attention_mask = self.attention_mask[keep_indices]
if self.decoder_attention_mask is not None:
decoder_attention_mask = self.decoder_attention_mask[keep_indices]
else:
decoder_attention_mask = None
encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
past_key_values = [
[t[keep_indices] for t in layer] for layer in self.past_key_values
]
return Seq2SeqLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=self.padding_right_offset,
)
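
The keep_indices collected in filter above drive plain fancy indexing on the batch tensors, so surviving rows keep their relative order while dropped rows disappear. A toy sketch:

import torch

# Hypothetical 3-request batch; request 1 was dropped because its client disconnected
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0], [1, 0, 0]])
keep_indices = [0, 2]
filtered = attention_mask[keep_indices]
assert filtered.tolist() == [[1, 1, 1], [1, 0, 0]]
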
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max(
max_decoder_input_length, batch.max_decoder_input_length
@@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch):
# Batch attributes
requests = []
requests_idx_mapping = {}
all_decoder_input_ids = []
input_lengths = []
decoder_input_lengths = []
offsets = []
@@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch):
for i, batch in enumerate(batches):
# Extend all list attributes
requests.extend(batch.requests)
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
offsets.extend(batch.offsets)
@@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + batch.size
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.encoder_last_hidden_state is None:
@@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch):
# Create padded tensor
if decoder_input_ids is None:
decoder_input_ids = batch.decoder_input_ids.new_zeros(
(total_batch_size, max_decoder_input_length),
(total_batch_size, 1),
)
# Copy to correct indices
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
decoder_input_ids[start_index:end_index] = batch.decoder_input_ids
# Create padded tensor
if decoder_attention_mask is None:
@@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch):
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
start_index += batch.size
start_index += len(batch)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
@@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch):
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@@ -413,46 +513,25 @@ class Seq2SeqLM(Model):
else:
decoder_attention_mask = None
# check if first forward or not
if batch.past_key_values is not None:
# Only take the last token
decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
else:
decoder_input_ids = batch.decoder_input_ids
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if batch.encoder_last_hidden_state is not None:
encoder_last_hidden_state = [batch.encoder_last_hidden_state]
else:
encoder_last_hidden_state = batch.encoder_last_hidden_state
encoder_last_hidden_state = None
logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids,
batch.attention_mask,
decoder_input_ids,
batch.decoder_input_ids,
decoder_attention_mask,
encoder_last_hidden_state,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
next_batch_max_decoder_input_length = 0
# Finished requests
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
@@ -464,7 +543,7 @@ class Seq2SeqLM(Model):
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.decoder_input_ids,
batch.all_decoder_input_ids,
)
# For each member of the batch
@@ -477,22 +556,24 @@ class Seq2SeqLM(Model):
logits,
next_token_chooser,
stopping_criteria,
decoder_input_ids,
all_decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
all_decoder_input_ids.view(1, -1), logits
)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)])
all_decoder_input_ids = torch.cat(
[all_decoder_input_ids, next_token_id.squeeze(1)]
)
new_decoder_input_length = decoder_input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token(
decoder_input_ids, offset, token_offset
all_decoder_input_ids, offset, token_offset
)
# Evaluate stopping criteria
@@ -501,7 +582,7 @@ class Seq2SeqLM(Model):
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens
output_text = self.decode(decoder_input_ids[-decoder_input_length:])
output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@@ -515,19 +596,7 @@ class Seq2SeqLM(Model):
else:
# Keep request in the batch
generated_text = None
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
@@ -551,69 +620,29 @@ class Seq2SeqLM(Model):
generations.append(generation)
# Update values
batch.decoder_input_ids[i] = next_token_id
batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length
batch.decoder_input_lengths[i] = new_decoder_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, input_length)
batch.max_decoder_input_length = max(
batch.max_decoder_input_length, new_decoder_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
if stopped:
return generations, None
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to decoder_attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices
]
else:
next_batch_decoder_attention_mask = None
next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
next_batch_keep_indices
]
next_batch_past_key_values = [
[t[next_batch_keep_indices] for t in layer] for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# We don't need input_ids after the prefill forward
batch.input_ids = None
batch.encoder_last_hidden_state = encoder_last_hidden_state
batch.past_key_values = past
# Update decoder_attention_mask as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
if batch.decoder_attention_mask is not None:
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=None,
attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask,
encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length,
padding_right_offset=batch.padding_right_offset - 1,
)
return generations, next_batch
return generations, batch
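
The decoder_attention_mask update at the end of generate_token marks the slot of the token produced this step and then shrinks the remaining right padding. A standalone sketch of that in-place bookkeeping (mask values are hypothetical):

import torch

# One request whose decoder mask was right-padded for two future tokens
decoder_attention_mask = torch.tensor([[1, 1, 0, 0]])
padding_right_offset = 2
# Enable the position written this step, then one less slot of padding remains
decoder_attention_mask[:, -padding_right_offset] = 1
padding_right_offset -= 1
assert decoder_attention_mask.tolist() == [[1, 1, 1, 0]]
assert padding_right_offset == 1
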

View File

@@ -25,6 +25,10 @@ class Batch(ABC):
) -> "Batch":
raise NotImplementedError
@abstractmethod
def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
raise NotImplementedError
@classmethod
@abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":

View File

@@ -60,7 +60,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.cache.pop(batch_pb.id)
if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batches.append(batch)
batch = batch.filter(batch_pb.requests)
if batch is not None:
batches.append(batch)
if len(batches) == 0:
raise ValueError("All batches are empty")
if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches)
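
The hunk above shows the service popping each cached batch, filtering out requests whose clients went away, skipping batches that became empty, and concatenating only when more than one batch remains. A condensed, illustrative sketch of that flow (collect_batches, cache, and model are placeholder names, not the service's actual structure):

def collect_batches(cache, model, batch_pbs):
    batches = []
    for batch_pb in batch_pbs:
        batch = cache.pop(batch_pb.id)
        if batch is None:
            raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
        # Drop requests whose clients closed the channel before the next forward
        batch = batch.filter(batch_pb.requests)
        if batch is not None:
            batches.append(batch)
    if not batches:
        raise ValueError("All batches are empty")
    return model.batch_type.concatenate(batches) if len(batches) > 1 else batches[0]
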