feat(router): drop requests when client closes the channel (#202)

Authored by OlivierDehaene on 2023-04-20 11:07:40 +02:00; committed by GitHub
parent b6ee0ec7b0
commit 709d8936f6
15 changed files with 821 additions and 571 deletions
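The router-side idea behind this change is that every request owns a flume response channel whose receiving half lives in the HTTP handler; when the client closes the connection, that receiver is dropped, and the queue/batching code can detect it and stop spending compute on the request. Below is a minimal, hedged sketch of that detection outside the router's own types (the `Entry` struct and message here are illustrative stand-ins, assuming only the flume crate):

    use flume::Sender;

    // Illustrative stand-in for the router's per-request entry: all that matters
    // for drop detection is the sending half of the response channel.
    struct Entry {
        response_tx: Sender<String>,
    }

    fn main() {
        let (response_tx, response_rx) = flume::unbounded::<String>();
        let entry = Entry { response_tx };

        // Simulate the client closing its connection: the receiving half is dropped.
        drop(response_rx);

        // The queue can now skip this entry before it ever reaches the model,
        // the same kind of check `next_batch` performs with `is_disconnected()`.
        if entry.response_tx.is_disconnected() {
            println!("client dropped the request; skipping entry");
        }
    }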


@@ -66,7 +66,7 @@ jobs:
           password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
           registry: registry.internal.huggingface.tech
       - name: Login to Azure Container Registry
-        # if: github.event_name != 'pull_request'
+        if: github.event_name != 'pull_request'
         uses: docker/login-action@v2.1.0
         with:
           username: ${{ secrets.AZURE_DOCKER_USERNAME }}

Cargo.lock generated

@@ -42,42 +42,51 @@ dependencies = [
 [[package]]
 name = "anstream"
-version = "0.2.6"
+version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f"
+checksum = "9e579a7752471abc2a8268df8b20005e3eadd975f585398f17efcfd8d4927371"
 dependencies = [
  "anstyle",
  "anstyle-parse",
+ "anstyle-query",
  "anstyle-wincon",
- "concolor-override",
- "concolor-query",
+ "colorchoice",
  "is-terminal",
  "utf8parse",
 ]

 [[package]]
 name = "anstyle"
-version = "0.3.5"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2"
+checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d"

 [[package]]
 name = "anstyle-parse"
-version = "0.1.1"
+version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116"
+checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee"
 dependencies = [
  "utf8parse",
 ]

 [[package]]
-name = "anstyle-wincon"
-version = "0.2.0"
+name = "anstyle-query"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa"
+checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b"
+dependencies = [
+ "windows-sys 0.48.0",
+]
+
+[[package]]
+name = "anstyle-wincon"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4bcd8291a340dd8ac70e18878bc4501dd7b4ff970cfa21c207d36ece51ea88fd"
 dependencies = [
  "anstyle",
- "windows-sys 0.45.0",
+ "windows-sys 0.48.0",
 ]

 [[package]]

@@ -105,7 +114,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -116,7 +125,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -127,9 +136,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
 [[package]]
 name = "axum"
-version = "0.6.13"
+version = "0.6.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6539e4565c365448d483967c6dee3eaecb8e87679a17806a831e82b05b903c18"
+checksum = "3b32c5ea3aabaf4deb5f5ced2d688ec0844c881c9e6c696a8b769a05fc691e62"
 dependencies = [
  "async-trait",
  "axum-core",

@@ -310,9 +319,9 @@ dependencies = [
 [[package]]
 name = "clap"
-version = "4.2.1"
+version = "4.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3"
+checksum = "9b802d85aaf3a1cdb02b224ba472ebdea62014fccfcb269b95a4d76443b5ee5a"
 dependencies = [
  "clap_builder",
  "clap_derive",

@@ -321,9 +330,9 @@ dependencies = [
 [[package]]
 name = "clap_builder"
-version = "4.2.1"
+version = "4.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f"
+checksum = "14a1a858f532119338887a4b8e1af9c60de8249cd7bafd68036a489e261e37b6"
 dependencies = [
  "anstream",
  "anstyle",

@@ -341,7 +350,7 @@ dependencies = [
  "heck",
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -351,19 +360,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1"

 [[package]]
-name = "concolor-override"
+name = "colorchoice"
 version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f"
-
-[[package]]
-name = "concolor-query"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf"
-dependencies = [
- "windows-sys 0.45.0",
-]
+checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"

 [[package]]
 name = "console"

@@ -794,7 +794,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -868,9 +868,9 @@ dependencies = [
 [[package]]
 name = "h2"
-version = "0.3.16"
+version = "0.3.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
+checksum = "17f8a914c2987b688368b5138aa05321db91f4090cf26118185672ad588bce21"
 dependencies = [
  "bytes",
  "fnv",

@@ -966,9 +966,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
 [[package]]
 name = "hyper"
-version = "0.14.25"
+version = "0.14.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899"
+checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
 dependencies = [
  "bytes",
  "futures-channel",

@@ -1364,7 +1364,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -1517,7 +1517,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -1787,9 +1787,9 @@ dependencies = [
 [[package]]
 name = "prost"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537"
+checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
 dependencies = [
  "bytes",
  "prost-derive",

@@ -1797,9 +1797,9 @@ dependencies = [
 [[package]]
 name = "prost-build"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12"
+checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
 dependencies = [
  "bytes",
  "heck",

@@ -1819,9 +1819,9 @@ dependencies = [
 [[package]]
 name = "prost-derive"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b"
+checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
 dependencies = [
  "anyhow",
  "itertools 0.10.5",

@@ -1832,9 +1832,9 @@ dependencies = [
 [[package]]
 name = "prost-types"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88"
+checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
 dependencies = [
  "prost",
 ]

@@ -2153,14 +2153,14 @@ checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]
 name = "serde_json"
-version = "1.0.95"
+version = "1.0.96"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
+checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
 dependencies = [
  "itoa",
  "ryu",

@@ -2330,9 +2330,9 @@ dependencies = [
 [[package]]
 name = "syn"
-version = "2.0.14"
+version = "2.0.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5"
+checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822"
 dependencies = [
  "proc-macro2",
  "quote",

@@ -2450,7 +2450,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -2578,7 +2578,7 @@ checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]

@@ -2928,9 +2928,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
 [[package]]
 name = "utoipa"
-version = "3.2.1"
+version = "3.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "24e7ee17c9ef094b86e1e04170d90765bd76cb381921dacb4d3e175a267bdae6"
+checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
 dependencies = [
  "indexmap",
  "serde",

@@ -2940,14 +2940,14 @@ dependencies = [
 [[package]]
 name = "utoipa-gen"
-version = "3.2.1"
+version = "3.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "df6f458e5abc811d44aca28455efc4163fb7565a7af2aa32d17611f3d1d9794d"
+checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b"
 dependencies = [
  "proc-macro-error",
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]

 [[package]]


@@ -3,6 +3,7 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
+use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;

@@ -11,7 +12,7 @@ use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{Notify, Semaphore, TryAcquireError};
+use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};

@@ -73,9 +74,14 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            RecvStream<Result<InferStreamResponse, InferError>>,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
+        // This permit will live as long as Entry
         let permit = self
             .clone()
             .limit_concurrent_requests

@@ -104,7 +110,6 @@ impl Infer {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         });

         // Notify the background task that we have a new entry in the queue that needs

@@ -112,7 +117,7 @@
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok(response_rx.into_stream())
+        Ok((permit, response_rx.into_stream()))
     }

     /// Add a new request to the queue and return a InferResponse

@@ -121,8 +126,8 @@
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
-        // Create stream
-        let mut stream = self.generate_stream(request).await?;
+        // Create stream and keep semaphore permit as long as generate lives
+        let (_permit, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();

@@ -276,12 +281,10 @@ async fn batching_task(
                     .next_batch(min_size, max_batch_size - batch_size as usize)
                     .await
                 {
-                    let new_batch_size = new_batch.size;
                     entries.iter_mut().for_each(|(_, entry)| {
                         // Create a new span to add the info that this entry is waiting
                         // because a new batch is being computed
-                        let entry_waiting_span =
-                            info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                         // Add relationships
                         span.follows_from(&entry_waiting_span);
                         entry_waiting_span.follows_from(&span);

@@ -308,8 +311,7 @@ async fn batching_task(
                     info_span!(parent: None, "batch", batch_size = next_batch_size);
                 entries.iter_mut().for_each(|(_, entry)| {
                     // Create a new span to link the batch back to this entry
-                    let entry_batch_span =
-                        info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
+                    let entry_batch_span = info_span!(parent: &entry.span, "infer");
                     // Add relationships
                     next_batch_span.follows_from(&entry_batch_span);
                     entry_batch_span.follows_from(&next_batch_span);

@@ -339,7 +341,23 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            filter_send_generations(generations, entries);
+
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
+                }
+            };
+
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch

@@ -361,17 +379,37 @@ async fn decode(
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
+    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            filter_send_generations(generations, entries);
+
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
+                }
+            };
+
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            for id in batch_ids {
+                let _ = client.clear_cache(Some(id)).await;
+            }
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None

@@ -379,6 +417,86 @@ async fn decode(
         }
     }
 }
+
+/// Filter a `batch` and remove all requests not present in `entries`
+#[instrument(skip_all)]
+fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+    batch.requests.retain(|r| entries.contains_key(&r.id));
+    let size = batch.requests.len();
+    if size == 0 {
+        return None;
+    }
+    batch.size = size as u32;
+    Some(batch)
+}
+
+/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// and filter entries
+#[instrument(skip_all)]
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+    generations.into_iter().for_each(|generation| {
+        let id = generation.request_id;
+        // Get entry
+        // We can `expect` here as the request id should always be in the entries
+        let entry = entries
+            .get(&id)
+            .expect("ID not found in entries. This is a bug.");
+
+        // Create and enter a span to link this function back to the entry
+        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+        // Send generation responses back to the infer task
+        // If the receive an error from the Flume channel, it means that the client dropped the
+        // request and we need to stop generating hence why we unwrap_or(true)
+        let stopped = send_responses(generation, entry).map_err(|err| {
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        }).unwrap_or(true);
+        if stopped {
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
+        }
+    });
+}
+
+/// Send responses through the `entry` response channel
+fn send_responses(
+    generation: Generation,
+    entry: &Entry,
+) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+    let mut stopped = false;
+
+    if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Send message
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+    }
+
+    // Create last Token
+    let token = Token {
+        id: generation.token_id,
+        text: generation.token_text,
+        logprob: generation.token_logprob,
+        special: generation.token_is_special,
+    };
+
+    if let Some(generated_text) = generation.generated_text {
+        // Generation has ended
+        stopped = true;
+        // Send message
+        entry.response_tx.send(Ok(InferStreamResponse::End {
+            token,
+            generated_text,
+            queued: entry.queue_time,
+            start: entry.batch_time.unwrap(),
+        }))?;
+    } else {
+        // Send message
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Token(token)))?;
+    }
+
+    Ok(stopped)
+}
+
 /// Send errors to Infer for all `entries`
 #[instrument(skip_all)]
 fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {

@@ -397,65 +515,6 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     });
 }

-/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
-#[instrument(skip_all)]
-fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
-    generations.into_iter().for_each(|generation| {
-        // Get entry
-        // We can `expect` here as the request id should always be in the entries
-        let entry = entries
-            .get(&generation.request_id)
-            .expect("ID not found in entries. This is a bug.");
-
-        // Create and enter a span to link this function back to the entry
-        let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
-
-        if let Some(prefill_tokens) = generation.prefill_tokens {
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
-                .unwrap_or(());
-        }
-
-        // Create last Token
-        let token = Token {
-            id: generation.token_id,
-            text: generation.token_text,
-            logprob: generation.token_logprob,
-            special: generation.token_is_special,
-        };
-
-        if let Some(generated_text) = generation.generated_text {
-            // Remove entry as this is the last message
-            // We can `expect` here as the request id should always be in the entries
-            let entry = entries
-                .remove(&generation.request_id)
-                .expect("ID not found in entries. This is a bug.");
-
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::End {
-                    token,
-                    generated_text,
-                    queued: entry.queue_time,
-                    start: entry.batch_time.unwrap(),
-                }))
-                .unwrap_or(());
-        } else {
-            // Send message
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry
-                .response_tx
-                .send(Ok(InferStreamResponse::Token(token)))
-                .unwrap_or(());
-        }
-    });
-}
-
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
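For requests that are already inside a running batch, the same "client is gone" signal arrives as a send error: `send_responses` above returns a `Result`, and `filter_send_generations` treats a failed send exactly like a finished request, removing the entry. A small, hedged illustration of that mapping with a bare flume channel (no router types; assumes only the flume crate):

    fn main() {
        let (response_tx, response_rx) = flume::unbounded::<u32>();

        // The "client" goes away: its receiver is dropped.
        drop(response_rx);

        // A send on a disconnected flume channel returns Err(SendError(value)).
        // The caller can map that error onto `stopped = true`, the same effect
        // `unwrap_or(true)` has in `filter_send_generations`.
        let stopped = response_tx.send(42).map(|_| false).unwrap_or(true);

        assert!(stopped);
        println!("request treated as stopped: {stopped}");
    }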


@@ -3,8 +3,9 @@ use crate::infer::InferStreamResponse;
 use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
+use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
-use tokio::sync::{oneshot, OwnedSemaphorePermit};
+use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

@@ -23,8 +24,6 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
-    /// Permit
-    pub _permit: OwnedSemaphorePermit,
 }

 /// Request Queue

@@ -104,7 +103,7 @@ async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
 #[derive(Debug)]
 struct State {
     /// Queue entries organized in a Vec
-    entries: Vec<(u64, Entry)>,
+    entries: VecDeque<(u64, Entry)>,

     /// Id of the next entry
     next_id: u64,

@@ -116,7 +115,7 @@ struct State {
 impl State {
     fn new() -> Self {
         Self {
-            entries: Vec::with_capacity(128),
+            entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
         }

@@ -129,7 +128,7 @@ impl State {
         entry.temp_span = Some(queue_span);

         // Push entry in the queue
-        self.entries.push((self.next_id, entry));
+        self.entries.push_back((self.next_id, entry));
         self.next_id += 1;
         metrics::increment_gauge!("tgi_queue_size", 1.0);
     }

@@ -147,51 +146,70 @@ impl State {
             }
         }

-        let next_batch_size = min(self.entries.len(), max_size);
+        let max_batch_size = min(self.entries.len(), max_size);

         // Create span for this batch to add context to inference calls
-        let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
+        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());

-        let mut batch_requests = Vec::with_capacity(next_batch_size);
+        let mut batch_requests = Vec::with_capacity(max_batch_size);
         let mut batch_entries =
-            IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
+            IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());

-        // Drain next_batch_size entries
-        self.entries
-            .drain(..next_batch_size)
-            .for_each(|(id, mut entry)| {
-                // Create a new span to link the batch back to this entry
-                let entry_batch_span =
-                    info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                // Add relationships
-                next_batch_span.follows_from(&entry_batch_span);
-                entry_batch_span.follows_from(&next_batch_span);
-                // Update entry
-                entry.temp_span = Some(entry_batch_span);
-
-                batch_requests.push(Request {
-                    id,
-                    inputs: entry.request.inputs.clone(),
-                    truncate: entry.request.truncate,
-                    parameters: Some(entry.request.parameters.clone()),
-                    stopping_parameters: Some(entry.request.stopping_parameters.clone()),
-                });
-                // Set batch_time
-                entry.batch_time = Some(Instant::now());
-                // Insert in batch_entries IntMap
-                batch_entries.insert(id, entry);
-            });
+        // Iterate on buffer
+        while let Some((id, mut entry)) = self.entries.pop_front() {
+            // Filter entries where the response receiver was dropped (== entries where the request
+            // was dropped by the client)
+            if entry.response_tx.is_disconnected() {
+                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                continue;
+            }
+
+            // Create a new span to link the batch back to this entry
+            let entry_batch_span = info_span!(parent: &entry.span, "infer");
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
+            entry_batch_span.follows_from(&next_batch_span);
+            // Update entry
+            entry.temp_span = Some(entry_batch_span);
+
+            batch_requests.push(Request {
+                id,
+                inputs: entry.request.inputs.clone(),
+                truncate: entry.request.truncate,
+                parameters: Some(entry.request.parameters.clone()),
+                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+            });
+            // Set batch_time
+            entry.batch_time = Some(Instant::now());
+            // Insert in batch_entries IntMap
+            batch_entries.insert(id, entry);
+
+            if batch_requests.len() == max_batch_size {
+                // We have enough requests in the batch
+                break;
+            }
+        }
+
+        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
+
+        // Maybe all entries were dropped because their channel were closed
+        if batch_requests.is_empty() {
+            return None;
+        }
+
+        // Final batch size once we dropped entries
+        let size = batch_requests.len() as u32;
+        next_batch_span.record("batch_size", size);

         let batch = Batch {
             id: self.next_batch_id,
             requests: batch_requests,
-            size: next_batch_size as u32,
+            size,
         };
         // Increment batch id
         self.next_batch_id += 1;
-        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
         metrics::histogram!("tgi_batch_next_size", batch.size as f64);

         Some((batch_entries, batch, next_batch_span))
     }

@@ -213,17 +231,16 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::sync::Arc;
     use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
-    use tokio::sync::Semaphore;
     use tracing::info_span;

-    fn default_entry() -> Entry {
-        let semaphore = Arc::new(Semaphore::new(1));
-        let (response_tx, _) = flume::unbounded();
-        let permit = semaphore.try_acquire_owned().unwrap();
+    fn default_entry() -> (
+        Entry,
+        flume::Receiver<Result<InferStreamResponse, InferError>>,
+    ) {
+        let (response_tx, receiver_tx) = flume::unbounded();

-        Entry {
+        let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
                 truncate: 0,

@@ -248,14 +265,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
-        }
+        };
+        (entry, receiver_tx)
     }

     #[test]
     fn test_append() {
         let mut state = State::new();
-        let entry = default_entry();
+        let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
         assert_eq!(state.entries.len(), 0);

@@ -264,7 +281,7 @@
         assert_eq!(state.next_id, 1);
         assert_eq!(state.entries.len(), 1);

-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 0);
     }

@@ -279,8 +296,10 @@
     #[test]
     fn test_next_batch_min_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 2).unwrap();
         assert_eq!(entries.len(), 2);

@@ -295,21 +314,24 @@
         assert_eq!(state.entries.len(), 0);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         assert!(state.next_batch(Some(2), 2).is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
-        let (id, _) = state.entries.remove(0);
+        let (id, _) = state.entries.remove(0).unwrap();
         assert_eq!(id, 2);
     }

     #[test]
     fn test_next_batch_max_size() {
         let mut state = State::new();
-        state.append(default_entry());
-        state.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        state.append(entry1);
+        state.append(entry2);

         let (entries, batch, _) = state.next_batch(None, 1).unwrap();
         assert_eq!(entries.len(), 1);

@@ -321,7 +343,8 @@
         assert_eq!(state.entries.len(), 1);
         assert_eq!(state.next_batch_id, 1);

-        state.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        state.append(entry3);

         let (entries, batch, _) = state.next_batch(None, 3).unwrap();
         assert_eq!(entries.len(), 2);

@@ -338,7 +361,8 @@
     #[tokio::test]
     async fn test_queue_append() {
         let queue = Queue::new();
-        queue.append(default_entry());
+        let (entry, _guard) = default_entry();
+        queue.append(entry);
     }

     #[tokio::test]

@@ -352,8 +376,10 @@
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
         assert_eq!(entries.len(), 2);

@@ -364,7 +390,8 @@
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 2);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         assert!(queue.next_batch(Some(2), 2).await.is_none());
     }

@@ -372,8 +399,10 @@
     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
         let queue = Queue::new();
-        queue.append(default_entry());
-        queue.append(default_entry());
+        let (entry1, _guard1) = default_entry();
+        let (entry2, _guard2) = default_entry();
+        queue.append(entry1);
+        queue.append(entry2);

         let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
         assert_eq!(entries.len(), 1);

@@ -381,7 +410,8 @@
         assert_eq!(batch.id, 0);
         assert_eq!(batch.size, 1);

-        queue.append(default_entry());
+        let (entry3, _guard3) = default_entry();
+        queue.append(entry3);

         let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
         assert_eq!(entries.len(), 2);

@@ -390,4 +420,13 @@
         assert_eq!(batch.id, 1);
         assert_eq!(batch.size, 2);
     }
+
+    #[tokio::test]
+    async fn test_queue_next_batch_dropped_receiver() {
+        let queue = Queue::new();
+        let (entry, _) = default_entry();
+        queue.append(entry);
+
+        assert!(queue.next_batch(None, 1).await.is_none());
+    }
 }
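The reworked `next_batch` above switches the buffer to a `VecDeque` and pops entries front-to-back, silently discarding any whose client has gone away and stopping once the batch is full. The sketch below mirrors only that loop shape, with a plain boolean standing in for the real `response_tx.is_disconnected()` check; it is an illustration under those assumptions, not the router's implementation:

    use std::collections::VecDeque;

    fn main() {
        // (id, receiver_still_connected) pairs; the bool stands in for
        // `!entry.response_tx.is_disconnected()` in the real queue.
        let mut entries: VecDeque<(u64, bool)> =
            VecDeque::from([(0, true), (1, false), (2, true), (3, true)]);
        let max_batch_size = 2;
        let mut batch_ids = Vec::new();

        while let Some((id, connected)) = entries.pop_front() {
            if !connected {
                // Dropped request: never forwarded to the model shards.
                continue;
            }
            batch_ids.push(id);
            if batch_ids.len() == max_batch_size {
                break; // the batch is full
            }
        }

        assert_eq!(batch_ids, vec![0, 2]);
        println!("batched ids: {batch_ids:?}, still queued: {}", entries.len());
    }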


@@ -367,7 +367,8 @@ async fn generate_stream(
     let best_of = req.0.parameters.best_of.unwrap_or(1);
     if best_of == 1 {
         match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
-            Ok(mut response_stream) => {
+            // Keep permit as long as generate_stream lives
+            Ok((_permit, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
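Returning the `OwnedSemaphorePermit` together with the stream ties the concurrency slot to the handler itself: when the SSE handler exits (including because the client disconnected), the permit is dropped and the slot is freed. A small tokio-only sketch of that lifetime coupling, with an illustrative limit of 2 permits (assumes tokio with the sync, rt and macros features):

    use std::sync::Arc;
    use tokio::sync::Semaphore;

    #[tokio::main]
    async fn main() {
        // Illustrative limiter, analogous to the router's `limit_concurrent_requests`.
        let limiter = Arc::new(Semaphore::new(2));

        // `acquire_owned` returns a permit that can be moved into the response
        // stream; it is released when the permit is dropped, e.g. when the
        // streaming handler returns because the client went away.
        let permit = limiter
            .clone()
            .acquire_owned()
            .await
            .expect("semaphore closed");
        assert_eq!(limiter.available_permits(), 1);

        drop(permit); // end of the streaming response
        assert_eq!(limiter.available_permits(), 2);
    }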


@@ -1,6 +1,9 @@
 include Makefile-transformers
 include Makefile-flash-att

+unit-tests:
+	python -m pytest tests
+
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir


@@ -45,8 +45,9 @@ def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
 @pytest.fixture
 def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)

@@ -70,12 +71,17 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.past_key_values is None

-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]

@@ -97,7 +103,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last

-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)

@@ -106,7 +112,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert torch.all(next_batch.attention_mask[0][:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)

-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 10264

     assert next_batch.input_lengths == [2]

@@ -170,6 +176,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens

@@ -269,6 +277,8 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens

@@ -290,6 +300,8 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
         - default_bloom_batch.stopping_criterias[0].max_new_tokens


@@ -44,11 +44,12 @@ def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
 @pytest.fixture
 def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

-    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))

@@ -67,12 +68,17 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.past_key_values is None

-    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+    assert all(
+        [
+            torch.equal(input_ids, all_input_ids[:, 0])
+            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
+        ]
+    )

     assert batch.input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]

@@ -93,7 +99,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)

-    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13

@@ -103,7 +109,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][0:2] == 1)
     assert torch.all(next_batch.attention_mask[0][2:] == 0)

-    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids.shape == (len(next_batch), 1)
     assert next_batch.input_ids[0, 0] == 13

     assert next_batch.input_lengths == [2]

@@ -168,6 +174,8 @@ def test_causal_lm_generate_token_completion_multi(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens

@@ -266,6 +274,8 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens

@@ -285,6 +295,8 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
         - default_causal_lm_batch.stopping_criterias[0].max_new_tokens


@@ -49,8 +49,9 @@ def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
 @pytest.fixture
 def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
     req_0 = copy(default_pb_request)
+    req_0.id = 1
     req_1 = default_pb_request
-    req_1.id = 1
+    req_1.id = 2
     req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)

@@ -72,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert torch.all(batch.attention_mask[0][-2:] == 1)
     assert torch.all(batch.attention_mask[0][:-2] == 0)

-    assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
+    assert len(batch.decoder_input_ids) == default_pb_batch.size
     assert batch.decoder_attention_mask is None
     assert batch.encoder_last_hidden_state is None

@@ -81,8 +82,8 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     assert batch.input_lengths == [2]
     assert batch.decoder_input_lengths == [1]

-    assert batch.size == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+    assert len(batch) == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)

     assert batch.max_input_length == batch.input_lengths[0]
     assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]

@@ -117,9 +118,9 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     )
     assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias

-    assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
-    assert next_batch.decoder_input_ids[0, 0] == 0
-    assert next_batch.decoder_input_ids[0, 1] == 259
+    assert len(next_batch.decoder_input_ids) == len(next_batch)
+    assert next_batch.all_decoder_input_ids[0][0] == 0
+    assert next_batch.all_decoder_input_ids[0][1] == 259
     assert next_batch.decoder_attention_mask is None
     assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)

@@ -128,20 +129,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
         [
-            p[2].shape == (next_batch.size, 6, sequence_length, 64)
+            p[2].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )
     assert all(
         [
-            p[3].shape == (next_batch.size, 6, sequence_length, 64)
+            p[3].shape == (len(next_batch), 6, sequence_length, 64)
             for p in next_batch.past_key_values
         ]
     )

@@ -189,6 +190,8 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

@@ -223,7 +226,8 @@ def test_batch_concatenate(
     assert torch.equal(
         next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
     )
-    assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
+    assert next_batch.all_decoder_input_ids[1][0] == 0
+    assert next_batch.all_decoder_input_ids[2][0] == 0
     assert torch.equal(
         next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
     )

@@ -258,16 +262,16 @@ def test_batch_concatenate(
     assert next_batch.past_key_values is not None
     assert all(
-        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
    )
     assert all(
-        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+        [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]
     )

     for i, past in enumerate(next_batch.past_key_values):

@@ -306,6 +310,8 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

+    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

@@ -314,6 +320,8 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

+    next_batch = next_batch.filter([next_batch.requests[1]])
+
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None


@ -3,7 +3,7 @@ import torch
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.models.types import ( from text_generation_server.models.types import (
@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
class CausalLMBatch(Batch): class CausalLMBatch(Batch):
batch_id: int batch_id: int
requests: List[generate_pb2.Request] requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values # Decoder values
input_ids: torch.Tensor input_ids: torch.Tensor
@ -42,7 +43,6 @@ class CausalLMBatch(Batch):
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
# Metadata used for padding # Metadata used for padding
size: int
max_input_length: int max_input_length: int
padding_right_offset: int padding_right_offset: int
@ -53,7 +53,7 @@ class CausalLMBatch(Batch):
return generate_pb2.Batch( return generate_pb2.Batch(
id=self.batch_id, id=self.batch_id,
requests=self.requests, requests=self.requests,
size=self.size, size=len(self),
) )
@classmethod @classmethod
@ -68,11 +68,13 @@ class CausalLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
offsets = [] offsets = []
token_offsets = [] token_offsets = []
requests_idx_mapping = {}
# Parse batch # Parse batch
max_truncation = 0 max_truncation = 0
padding_right_offset = 0 padding_right_offset = 0
for r in pb.requests: for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
offsets.append(None) offsets.append(None)
token_offsets.append(None) token_offsets.append(None)
@ -108,26 +110,91 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(), input_lengths=input_lengths.tolist(),
offsets=offsets, offsets=offsets,
token_offsets=token_offsets, token_offsets=token_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=pb.size,
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
) )
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
attention_mask = self.attention_mask[keep_indices]
position_ids = self.position_ids[keep_indices]
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_key_values = [
[t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
for layer in self.past_key_values
]
return CausalLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
padding_right_offset=self.padding_right_offset,
keys_head_dim_last=self.keys_head_dim_last,
)
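The filter() added above keeps rows by looking each surviving request id up in requests_idx_mapping and re-slicing every per-request list and tensor. A minimal sketch of the same bookkeeping, assuming hypothetical SimpleRequest/SimpleBatch stand-ins rather than the real protobuf and batch classes:

from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class SimpleRequest:
    id: int


@dataclass
class SimpleBatch:
    requests: List[SimpleRequest]
    requests_idx_mapping: Dict[int, int]  # request id -> row index
    input_lengths: List[int]
    input_ids: torch.Tensor  # [batch_size, seq_len]

    def filter(self, requests: List[SimpleRequest]) -> "SimpleBatch":
        # Rows to keep, in the order of the surviving requests
        keep = [self.requests_idx_mapping[r.id] for r in requests]
        return SimpleBatch(
            requests=requests,
            requests_idx_mapping={r.id: i for i, r in enumerate(requests)},
            input_lengths=[self.input_lengths[i] for i in keep],
            input_ids=self.input_ids[keep],  # fancy indexing drops the other rows
        )


batch = SimpleBatch(
    requests=[SimpleRequest(7), SimpleRequest(9), SimpleRequest(12)],
    requests_idx_mapping={7: 0, 9: 1, 12: 2},
    input_lengths=[5, 3, 8],
    input_ids=torch.arange(12).view(3, 4),
)
smaller = batch.filter([SimpleRequest(7), SimpleRequest(12)])
assert smaller.requests_idx_mapping == {7: 0, 12: 1}
assert smaller.input_lengths == [5, 8]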
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
@ -136,12 +203,13 @@ class CausalLMBatch(Batch):
max_input_length = 0 max_input_length = 0
padding_right_offset = 0 padding_right_offset = 0
for batch in batches: for batch in batches:
total_batch_size += batch.size total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length) max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset) padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes # Batch attributes
requests = [] requests = []
requests_idx_mapping = {}
input_lengths = [] input_lengths = []
offsets = [] offsets = []
token_offsets = [] token_offsets = []
@ -167,8 +235,15 @@ class CausalLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch # Slicing end index for this batch
end_index = start_index + batch.size end_index = start_index + len(batch)
# We only concatenate batches that did at least one step # We only concatenate batches that did at least one step
if batch.past_key_values is None: if batch.past_key_values is None:
@ -216,8 +291,8 @@ class CausalLMBatch(Batch):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim] # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:]) past_keys = past_keys.view(len(batch), -1, *past_keys.shape[-2:])
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:]) past_values = past_values.view(len(batch), -1, *past_values.shape[-2:])
_, num_heads, padded_sequence_length, head_dim = past_values.shape _, num_heads, padded_sequence_length, head_dim = past_values.shape
@ -265,11 +340,12 @@ class CausalLMBatch(Batch):
start_index:end_index, :, -(batch.max_input_length - 1) :, : start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_input_length - 1) :, :] ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index += batch.size start_index += len(batch)
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
@ -280,7 +356,6 @@ class CausalLMBatch(Batch):
token_offsets=token_offsets, token_offsets=token_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length, max_input_length=max_input_length,
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
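In concatenate() above, the request-id mapping of every batch after the first is shifted by the rows already stacked (start_index). A toy illustration of that offset arithmetic; merge_idx_mappings is a made-up helper, not part of the server:

from typing import Dict, List


def merge_idx_mappings(
    mappings: List[Dict[int, int]], batch_sizes: List[int]
) -> Dict[int, int]:
    """Merge per-batch {request id: row} dicts into one dict over the stacked batch."""
    merged: Dict[int, int] = {}
    offset = 0
    for mapping, size in zip(mappings, batch_sizes):
        for request_id, row in mapping.items():
            merged[request_id] = row + offset
        offset += size
    return merged


# Batch A holds requests 7 and 9, batch B holds request 12:
assert merge_idx_mappings([{7: 0, 9: 1}, {12: 0}], [2, 1]) == {7: 0, 9: 1, 12: 2}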
@ -364,22 +439,9 @@ class CausalLM(Model):
batch.past_key_values, batch.past_key_values,
) )
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_input_ids = []
next_batch_all_input_ids = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
# Results # Results
generations: List[Generation] = [] generations: List[Generation] = []
stopped = True
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
@ -443,16 +505,7 @@ class CausalLM(Model):
else: else:
# Keep request in the batch # Keep request in the batch
generated_text = None generated_text = None
next_batch_keep_indices.append(i) stopped = False
next_batch_input_ids.append(next_token_id)
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_max_input_length = max(
next_batch_max_input_length, new_input_length
)
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if stopping_criteria.current_tokens == 1:
@ -484,62 +537,30 @@ class CausalLM(Model):
generations.append(generation) generations.append(generation)
# Update values
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
if stopped:
return generations, None
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values = [
[
t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
for t in layer
]
for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_position_ids = batch.position_ids
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask as we added a new token to input_ids
next_batch_attention_mask[:, -batch.padding_right_offset] = 1
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
batch.position_ids = batch.position_ids[:, -1:] + 1
# Update past key values
batch.past_key_values = past
return generations, batch
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask,
position_ids=next_batch_position_ids,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last,
)
return generations, next_batch
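With the hunks above, CausalLM.generate_token mutates the batch in place (input_ids sliced to the last token, attention mask and position ids advanced, past key values stored) and returns None as the next batch only once every request has stopped. A hedged sketch of the loop a caller could build on that contract; decode_until_done is illustrative, the real driver is the Rust router talking over gRPC:

def decode_until_done(model, batch):
    """Drive generate_token until it signals completion by returning batch=None."""
    all_generations = []
    while batch is not None:
        generations, batch = model.generate_token(batch)
        all_generations.extend(generations)
    return all_generations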


@ -6,7 +6,7 @@ from torch.nn import functional as F
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.models.types import ( from text_generation_server.models.types import (
@ -29,14 +29,16 @@ tracer = trace.get_tracer(__name__)
class FlashCausalLMBatch(Batch): class FlashCausalLMBatch(Batch):
batch_id: int batch_id: int
requests: List[generate_pb2.Request] requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values # Decoder values
input_ids: torch.Tensor input_ids: List[torch.Tensor]
position_ids: torch.Tensor position_ids: List[torch.Tensor]
# cumulative sequence lengths # cumulative sequence lengths
cu_seqlens: torch.Tensor cu_seqlens: List[int]
max_seqlen: int max_seqlen: int
past_key_values: Optional[torch.Tensor] past_key_values: Optional[List[torch.Tensor]]
# All tokens # All tokens
all_input_ids: List[List[int]] all_input_ids: List[List[int]]
@ -62,7 +64,7 @@ class FlashCausalLMBatch(Batch):
pb: generate_pb2.Batch, pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
device: torch.device, device: torch.device,
) -> "CausalLMBatch": ) -> "FlashCausalLMBatch":
input_ids = [] input_ids = []
position_ids = [] position_ids = []
cu_seqlens = [0] cu_seqlens = [0]
@ -73,6 +75,7 @@ class FlashCausalLMBatch(Batch):
token_offsets = [] token_offsets = []
all_input_ids = [] all_input_ids = []
all_input_ids_tensor = [] all_input_ids_tensor = []
requests_idx_mapping = {}
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -81,13 +84,18 @@ class FlashCausalLMBatch(Batch):
cumulative_length = 0 cumulative_length = 0
# Parse batch # Parse batch
for r in pb.requests: for i, r in enumerate(pb.requests):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenizer( tokenized_input = tokenizer(
r.inputs, truncation=True, max_length=r.truncate r.inputs, truncation=True, max_length=r.truncate
)["input_ids"] )["input_ids"]
input_length = len(tokenized_input) input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length) input_lengths.append(input_length)
offsets.append(None) offsets.append(None)
token_offsets.append(None) token_offsets.append(None)
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
@ -96,7 +104,9 @@ class FlashCausalLMBatch(Batch):
input_ids.append(tokenized_input) input_ids.append(tokenized_input)
# Position ids # Position ids
position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) position_ids.append(
torch.arange(0, input_length, dtype=torch.int32, device=device)
)
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length) cu_seqlens.append(cumulative_length + input_length)
@ -113,13 +123,10 @@ class FlashCausalLMBatch(Batch):
# Update # Update
cumulative_length += input_length cumulative_length += input_length
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
@ -134,60 +141,141 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
) )
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(requests) == len(self):
return self
# Cumulative length
cumulative_length = 0
# New values after filtering
requests_idx_mapping = {}
input_ids = []
position_ids = []
cu_seqlens = [0]
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
next_token_choosers = []
stopping_criterias = []
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
# Get length
request_input_length = self.input_lengths[idx]
input_ids.append(self.input_ids[idx])
position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
past_key_values.append(self.past_key_values[idx])
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
cumulative_length += request_input_length
return FlashCausalLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
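FlashCausalLMBatch.filter rebuilds cu_seqlens from the kept lengths because the flash-attention kernels index a flattened token buffer by cumulative offsets. A small stand-alone sketch of that recomputation (hypothetical helper name):

from typing import List


def rebuild_cu_seqlens(kept_input_lengths: List[int]) -> List[int]:
    """Cumulative sequence lengths for a ragged batch: [0, l0, l0 + l1, ...]."""
    cu_seqlens = [0]
    total = 0
    for length in kept_input_lengths:
        total += length
        cu_seqlens.append(total)
    return cu_seqlens


assert rebuild_cu_seqlens([5, 3, 8]) == [0, 5, 8, 16]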
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
# Batch attributes # Batch attributes
requests = [] requests = []
input_lengths = [] requests_idx_mapping = {}
offsets = []
token_offsets = []
all_input_ids = []
all_input_ids_tensor = []
next_token_choosers = []
stopping_criterias = []
# Batch tensors
input_ids = [] input_ids = []
position_ids = [] position_ids = []
cu_seqlens = [torch.tensor([0], dtype=torch.int32)] cu_seqlens = [0]
max_seqlen = 0 max_seqlen = 0
past_key_values = [] past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
next_token_choosers = []
stopping_criterias = []
# Cumulative length # Cumulative length
cumulative_length = torch.tensor(0) cumulative_batch_size = 0
cumulative_length = 0
for i, batch in enumerate(batches): for i, batch in enumerate(batches):
requests.extend(batch.requests) requests.extend(batch.requests)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
input_ids.extend(batch.input_ids)
position_ids.extend(batch.position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
past_key_values.extend(batch.past_key_values)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets) offsets.extend(batch.offsets)
token_offsets.extend(batch.token_offsets) token_offsets.extend(batch.token_offsets)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
input_ids.append(batch.input_ids)
position_ids.append(batch.position_ids)
past_key_values.append(batch.past_key_values)
max_seqlen = max(max_seqlen, batch.max_seqlen)
# Update # Update
cumulative_length += batch.cu_seqlens[-1] cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
# Concat on dim=1 as first dim represents the model layers
past_key_values = torch.concat(past_key_values, dim=1)
cu_seqlens = torch.concat(cu_seqlens)
return FlashCausalLMBatch( return FlashCausalLMBatch(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
@ -269,38 +357,49 @@ class FlashCausalLM(Model):
def generate_token( def generate_token(
self, batch: FlashCausalLMBatch self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
# Better to send to device here to avoid device issues in concatenate
position_ids = batch.position_ids.to(self.device, non_blocking=True)
cu_seqlens = batch.cu_seqlens.to(self.device)
# Shortcut when batch_size == 1
if len(batch) == 1:
input_ids = batch.input_ids[0].view(-1)
past_key_values = (
batch.past_key_values[0] if batch.past_key_values is not None else None
)
else:
# Concatenate tensors
input_ids = torch.cat(batch.input_ids).view(-1)
past_key_values = (
torch.cat(batch.past_key_values, dim=1)
if batch.past_key_values is not None
else None
)
# Concatenate when prefill, torch.tensor when decode
position_ids = (
torch.tensor(batch.position_ids, device=self.device)
if batch.past_key_values is not None
else torch.cat(batch.position_ids)
)
cu_seqlens = torch.tensor(
batch.cu_seqlens, device=self.device, dtype=torch.int32
)
out, present = self.forward(
batch.input_ids,
input_ids,
position_ids,
cu_seqlens,
batch.max_seqlen,
batch.past_key_values,
past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# Initialize past_key_values in prefill
if batch.past_key_values is None:
batch.past_key_values = [None] * len(batch)
# New values for next forward
next_batch_input_ids = []
next_batch_position_ids = []
next_batch_cu_seqlens = [0]
next_batch_max_seqlen = 0
next_batch_past_key_values = []
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_all_input_ids = []
next_batch_all_input_ids_tensor = []
# Cumulative length # Cumulative length
cumulative_length = 0 cumulative_length = 0
# Results # Results
generations: List[Generation] = [] generations: List[Generation] = []
stopped = True
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
@ -329,7 +428,8 @@ class FlashCausalLM(Model):
start_index = cumulative_length start_index = cumulative_length
end_index = cumulative_length + input_length end_index = cumulative_length + input_length
if batch.past_key_values is None:
prefill = stopping_criteria.current_tokens == 0
if prefill:
# Prefill mode # Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size] # out is of shape [cumulative_sequence_lengths, vocab_size]
logits = out[start_index:end_index] logits = out[start_index:end_index]
@ -348,7 +448,6 @@ class FlashCausalLM(Model):
# Append next token to all tokens # Append next token to all tokens
all_input_ids.append(next_token_id_item) all_input_ids.append(next_token_id_item)
all_input_ids_tensor[input_length] = next_token_id_item all_input_ids_tensor[input_length] = next_token_id_item
new_input_length = input_length + 1
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id_item] next_token_logprob = logprobs[-1, next_token_id_item]
@ -378,32 +477,23 @@ class FlashCausalLM(Model):
generated_text = GeneratedText( generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed output_text, stopping_criteria.current_tokens, reason, seed
) )
# CAUTION: generation will be stopped so no need to pad
# This will make the next forward crash if the request does not get filtered
new_input_length = input_length
past = present[:, start_index:end_index]
else:
stopped = False
# Keep request in the batch
next_batch_keep_indices.append(i)
generated_text = None
# Pad present for next iter attention
new_input_length = input_length + 1
past = torch.nn.functional.pad(
present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
# Get sequence present
seq_present = present[:, start_index:end_index]
# Pad it for next iter attention
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
next_batch_past_key_values.append(past)
next_batch_input_ids.append(next_token_id)
next_batch_position_ids.append(input_length)
# Cumulative sum
next_batch_cu_seqlens.append(
next_batch_cu_seqlens[-1] + new_input_length
) )
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_all_input_ids.append(all_input_ids)
next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if prefill:
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather( prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids_tensor[1:input_length].unsqueeze(1) 1, all_input_ids_tensor[1:input_length].unsqueeze(1)
@ -433,52 +523,18 @@ class FlashCausalLM(Model):
generations.append(generation) generations.append(generation)
cumulative_length += input_length cumulative_length += input_length
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generations, None
# Update values
batch.input_ids[i] = next_token_id
batch.position_ids[i] = input_length
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.all_input_ids[i] = all_input_ids
batch.all_input_ids_tensor[i] = all_input_ids_tensor
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
batch.past_key_values[i] = past
# Cumulative sum
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to requests, token_choosers and stopping_criterias that need to be cached
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Create final next batch tensors
next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32
)
next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
else:
next_batch_input_ids = next_batch_input_ids[0].view(1)
next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashCausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
position_ids=next_batch_position_ids,
cu_seqlens=next_batch_cu_seqlens,
max_seqlen=next_batch_max_seqlen,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
all_input_ids=next_batch_all_input_ids,
all_input_ids_tensor=next_batch_all_input_ids_tensor,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
)
return generations, next_batch
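In the rewritten FlashCausalLM.generate_token above, per-request tensors stay in Python lists between steps and are only concatenated right before the forward pass, with a shortcut when the batch holds a single request. A sketch of that assembly under those assumptions; assemble_input_ids is an illustrative name, not the real API:

from typing import List

import torch


def assemble_input_ids(per_request_input_ids: List[torch.Tensor]) -> torch.Tensor:
    """Flatten a list of per-request token tensors into one 1D buffer for the kernel."""
    if len(per_request_input_ids) == 1:
        # Single-request shortcut: no concatenation needed
        return per_request_input_ids[0].view(-1)
    return torch.cat(per_request_input_ids).view(-1)


prompts = [torch.tensor([101, 7592]), torch.tensor([101, 2088, 999])]
assert assemble_input_ids(prompts).tolist() == [101, 7592, 101, 2088, 999]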


@ -96,11 +96,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
stopping_criterias = [] stopping_criterias = []
offsets = [] offsets = []
token_offsets = [] token_offsets = []
requests_idx_mapping = {}
# Parse batch # Parse batch
max_truncation = 0 max_truncation = 0
padding_right_offset = 0 padding_right_offset = 0
for r in pb.requests: for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic # Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs)) inputs.append(escape_custom_split_sequence(r.inputs))
offsets.append(None) offsets.append(None)
@ -115,7 +117,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset, stopping_criteria.max_new_tokens padding_right_offset, stopping_criteria.max_new_tokens
) )
# Tokenize batch
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
@ -138,23 +139,23 @@ class GalacticaCausalLMBatch(CausalLMBatch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=list(all_input_ids),
input_lengths=input_lengths, input_lengths=input_lengths.tolist(),
offsets=offsets, offsets=offsets,
token_offsets=token_offsets, token_offsets=token_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=pb.size, max_input_length=max_input_length.item(),
max_input_length=max_input_length,
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
) )


@ -3,7 +3,7 @@ import torch
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.models.types import ( from text_generation_server.models.types import (
@ -22,6 +22,7 @@ tracer = trace.get_tracer(__name__)
class Seq2SeqLMBatch(Batch): class Seq2SeqLMBatch(Batch):
batch_id: int batch_id: int
requests: List[generate_pb2.Request] requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Encoder values # Encoder values
input_ids: torch.Tensor input_ids: torch.Tensor
@ -32,6 +33,9 @@ class Seq2SeqLMBatch(Batch):
decoder_attention_mask: Optional[torch.Tensor] decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor] encoder_last_hidden_state: Optional[torch.Tensor]
# All tokens
all_decoder_input_ids: List[torch.Tensor]
# Seq2SeqLM keeps track of both encoder and decoder attention keys and values # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
past_key_values: Optional[List[Tuple]] past_key_values: Optional[List[Tuple]]
@ -46,7 +50,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
# Metadata used for padding # Metadata used for padding
size: int
max_input_length: int max_input_length: int
max_decoder_input_length: int max_decoder_input_length: int
padding_right_offset: int padding_right_offset: int
@ -54,9 +57,7 @@ class Seq2SeqLMBatch(Batch):
def to_pb(self) -> generate_pb2.Batch: def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf""" """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch( return generate_pb2.Batch(
id=self.batch_id, id=self.batch_id, requests=self.requests, size=len(self)
requests=self.requests,
size=self.size,
) )
@classmethod @classmethod
@ -71,18 +72,17 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
decoder_input_ids = []
decoder_input_lengths = [] decoder_input_lengths = []
offsets = [] offsets = []
token_offsets = [] token_offsets = []
requests_idx_mapping = {}
# Parse batch # Parse batch
max_truncation = 0 max_truncation = 0
padding_right_offset = 0 padding_right_offset = 0
for r in pb.requests: for i, r in enumerate(pb.requests):
inputs.append(r.inputs) inputs.append(r.inputs)
# Decoder sequence only contains the bos_token requests_idx_mapping[r.id] = i
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1) decoder_input_lengths.append(1)
offsets.append(None) offsets.append(None)
token_offsets.append(None) token_offsets.append(None)
@ -109,15 +109,22 @@ class Seq2SeqLMBatch(Batch):
input_lengths = tokenized_inputs["attention_mask"].sum(1) input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max() max_input_length = input_lengths.max()
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
# Decoder sequence only contains the bos_token
decoder_input_ids = (
torch.tensor(tokenizer.bos_token_id, device=device)
.repeat(len(pb.requests))
.view(-1, 1)
)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=tokenized_inputs["input_ids"], input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"], attention_mask=tokenized_inputs["attention_mask"],
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=list(all_decoder_input_ids),
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_last_hidden_state=None, encoder_last_hidden_state=None,
past_key_values=None, past_key_values=None,
@ -127,12 +134,96 @@ class Seq2SeqLMBatch(Batch):
token_offsets=token_offsets, token_offsets=token_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=len(pb.requests),
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
max_decoder_input_length=1, max_decoder_input_length=1,
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
) )
@tracer.start_as_current_span("filter")
def filter(
self, requests: List[generate_pb2.Request]
) -> Optional["Seq2SeqLMBatch"]:
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
decoder_input_lengths = []
offsets = []
token_offsets = []
all_decoder_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_input_length = 0
max_decoder_input_length = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
request_decoder_input_length = self.decoder_input_lengths[idx]
decoder_input_lengths.append(request_decoder_input_length)
max_decoder_input_length = max(
max_decoder_input_length, request_decoder_input_length
)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
decoder_input_ids = self.decoder_input_ids[keep_indices]
attention_mask = self.attention_mask[keep_indices]
if self.decoder_attention_mask is not None:
decoder_attention_mask = self.decoder_attention_mask[keep_indices]
else:
decoder_attention_mask = None
encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices]
past_key_values = [
[t[keep_indices] for t in layer] for layer in self.past_key_values
]
return Seq2SeqLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=self.padding_right_offset,
)
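Like the causal variants, Seq2SeqLMBatch.filter slices every cached tensor, including each layer's past key/value tensors, down to the kept rows. A minimal sketch of indexing such a nested cache (the shapes are made up for illustration):

from typing import List

import torch


def filter_past(past_key_values, keep_indices: List[int]):
    """Index every cached tensor of every layer with the rows that stay in the batch."""
    return [[t[keep_indices] for t in layer] for layer in past_key_values]


# Two layers, each with a (key, value) pair of shape [batch=3, heads=2, seq=4, dim=8]
past = [[torch.zeros(3, 2, 4, 8), torch.zeros(3, 2, 4, 8)] for _ in range(2)]
smaller = filter_past(past, keep_indices=[0, 2])
assert smaller[0][0].shape == (2, 2, 4, 8)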
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
@ -144,7 +235,7 @@ class Seq2SeqLMBatch(Batch):
max_decoder_input_length = 0 max_decoder_input_length = 0
padding_right_offset = 0 padding_right_offset = 0
for batch in batches: for batch in batches:
total_batch_size += batch.size total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length) max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max( max_decoder_input_length = max(
max_decoder_input_length, batch.max_decoder_input_length max_decoder_input_length, batch.max_decoder_input_length
@ -153,6 +244,8 @@ class Seq2SeqLMBatch(Batch):
# Batch attributes # Batch attributes
requests = [] requests = []
requests_idx_mapping = {}
all_decoder_input_ids = []
input_lengths = [] input_lengths = []
decoder_input_lengths = [] decoder_input_lengths = []
offsets = [] offsets = []
@ -174,6 +267,7 @@ class Seq2SeqLMBatch(Batch):
for i, batch in enumerate(batches): for i, batch in enumerate(batches):
# Extend all list attributes # Extend all list attributes
requests.extend(batch.requests) requests.extend(batch.requests)
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths)
offsets.extend(batch.offsets) offsets.extend(batch.offsets)
@ -181,8 +275,15 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch # Slicing end index for this batch
end_index = start_index + batch.size end_index = start_index + len(batch)
# We only concatenate batches that did at least one step # We only concatenate batches that did at least one step
if batch.encoder_last_hidden_state is None: if batch.encoder_last_hidden_state is None:
@ -201,12 +302,10 @@ class Seq2SeqLMBatch(Batch):
# Create padded tensor # Create padded tensor
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = batch.decoder_input_ids.new_zeros( decoder_input_ids = batch.decoder_input_ids.new_zeros(
(total_batch_size, max_decoder_input_length), (total_batch_size, 1),
) )
# Copy to correct indices # Copy to correct indices
decoder_input_ids[ decoder_input_ids[start_index:end_index] = batch.decoder_input_ids
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
# Create padded tensor # Create padded tensor
if decoder_attention_mask is None: if decoder_attention_mask is None:
@ -302,14 +401,16 @@ class Seq2SeqLMBatch(Batch):
start_index:end_index, :, -batch.max_input_length :, : start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :] ] = t[:, :, -batch.max_input_length :, :]
start_index += batch.size start_index += len(batch)
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None, input_ids=None,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state, encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values, past_key_values=past_key_values,
@ -319,7 +420,6 @@ class Seq2SeqLMBatch(Batch):
token_offsets=token_offsets, token_offsets=token_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length, max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length, max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
@ -413,46 +513,25 @@ class Seq2SeqLM(Model):
else: else:
decoder_attention_mask = None decoder_attention_mask = None
# check if first forward or not
if batch.past_key_values is not None:
# Only take the last token
decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
else:
decoder_input_ids = batch.decoder_input_ids
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]` # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally... # internally...
if batch.encoder_last_hidden_state is not None: if batch.encoder_last_hidden_state is not None:
encoder_last_hidden_state = [batch.encoder_last_hidden_state] encoder_last_hidden_state = [batch.encoder_last_hidden_state]
else: else:
encoder_last_hidden_state = batch.encoder_last_hidden_state encoder_last_hidden_state = None
logits, encoder_last_hidden_state, past = self.forward( logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, batch.attention_mask,
decoder_input_ids, batch.decoder_input_ids,
decoder_attention_mask, decoder_attention_mask,
encoder_last_hidden_state, encoder_last_hidden_state,
batch.past_key_values, batch.past_key_values,
) )
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
next_batch_max_decoder_input_length = 0
# Finished requests # Finished requests
generations: List[Generation] = [] generations: List[Generation] = []
stopped = True
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
@ -464,7 +543,7 @@ class Seq2SeqLM(Model):
logits, logits,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.decoder_input_ids, batch.all_decoder_input_ids,
) )
# For each member of the batch # For each member of the batch
@ -477,22 +556,24 @@ class Seq2SeqLM(Model):
logits, logits,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
decoder_input_ids, all_decoder_input_ids,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
next_token_id, logprobs = next_token_chooser( next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits all_decoder_input_ids.view(1, -1), logits
) )
# Append next token to decoder tokens # Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.squeeze(1)]) all_decoder_input_ids = torch.cat(
[all_decoder_input_ids, next_token_id.squeeze(1)]
)
new_decoder_input_length = decoder_input_length + 1 new_decoder_input_length = decoder_input_length + 1
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id] next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze() next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token( next_token_text, offset, token_offset = self.decode_token(
decoder_input_ids, offset, token_offset all_decoder_input_ids, offset, token_offset
) )
# Evaluate stopping criteria # Evaluate stopping criteria
@ -501,7 +582,7 @@ class Seq2SeqLM(Model):
if stop: if stop:
# Slice with decoder_input_length to remove padding # Slice with decoder_input_length to remove padding
# Decode all tokens # Decode all tokens
output_text = self.decode(decoder_input_ids[-decoder_input_length:]) output_text = self.decode(all_decoder_input_ids[-decoder_input_length:])
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):
@ -515,19 +596,7 @@ class Seq2SeqLM(Model):
else: else:
# Keep request in the batch # Keep request in the batch
generated_text = None generated_text = None
next_batch_keep_indices.append(i) stopped = False
next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if stopping_criteria.current_tokens == 1:
@ -551,69 +620,29 @@ class Seq2SeqLM(Model):
generations.append(generation) generations.append(generation)
# Update values
batch.decoder_input_ids[i] = next_token_id
batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length
batch.decoder_input_lengths[i] = new_decoder_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, input_length)
batch.max_decoder_input_length = max(
batch.max_decoder_input_length, new_decoder_input_length
)
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
# We don't need input_ids after the prefill forward
batch.input_ids = None
batch.encoder_last_hidden_state = encoder_last_hidden_state
batch.past_key_values = past
if not next_batch_keep_indices:
return generations, None
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to decoder_attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices
]
else:
next_batch_decoder_attention_mask = None
next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
next_batch_keep_indices
]
next_batch_past_key_values = [
[t[next_batch_keep_indices] for t in layer] for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update decoder_attention_mask as we added a new token to input_ids
if batch.decoder_attention_mask is not None:
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
return generations, batch
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=None,
attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask,
encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length,
padding_right_offset=batch.padding_right_offset - 1,
)
return generations, next_batch
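The decode path now advances the cached (decoder) attention mask in place: the slot at -padding_right_offset is set to 1 for the token just generated and the offset shrinks by one, in both CausalLM and Seq2SeqLM. A toy version of that bookkeeping (illustrative names only):

import torch


def advance_attention_mask(attention_mask: torch.Tensor, padding_right_offset: int) -> int:
    """Mark the freshly generated token as attendable and shrink the right padding."""
    attention_mask[:, -padding_right_offset] = 1
    return padding_right_offset - 1


mask = torch.tensor([[1, 1, 0, 0, 0]])
offset = advance_attention_mask(mask, padding_right_offset=3)
assert mask.tolist() == [[1, 1, 1, 0, 0]] and offset == 2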


@ -25,6 +25,10 @@ class Batch(ABC):
) -> "Batch": ) -> "Batch":
raise NotImplementedError raise NotImplementedError
@abstractmethod
def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
raise NotImplementedError
@classmethod @classmethod
@abstractmethod @abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch": def concatenate(cls, batches: List["Batch"]) -> "Batch":


@ -60,7 +60,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.cache.pop(batch_pb.id) batch = self.cache.pop(batch_pb.id)
if batch is None: if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batches.append(batch)
batch = batch.filter(batch_pb.requests)
if batch is not None:
batches.append(batch)
if len(batches) == 0:
raise ValueError("All batches are empty")
if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches)
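This is where client disconnects reach the Python server: a request the router dropped simply no longer appears in batch_pb.requests, so the cached batch is filtered down before concatenation. A hedged sketch of that control flow with stand-in cache/model/proto objects (not the real gRPC servicer):

def prune_cached_batches(cache, model, cached_batches_pb):
    """Filter every cached batch down to its surviving requests, then merge what is left."""
    batches = []
    for batch_pb in cached_batches_pb:
        batch = cache.pop(batch_pb.id)
        if batch is None:
            raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
        # Requests the client abandoned are simply absent from batch_pb.requests
        batch = batch.filter(batch_pb.requests)
        if batch is not None:
            batches.append(batch)
    if not batches:
        raise ValueError("All batches are empty")
    return batches[0] if len(batches) == 1 else model.batch_type.concatenate(batches)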