feat: add more latency metrics in forward (#1346)

Authored by OlivierDehaene on 2023-12-14 15:59:38 +01:00; committed by GitHub
parent 44b267ab22
commit 50b495f3d8
17 changed files with 240 additions and 110 deletions

View File

@@ -163,7 +163,7 @@ async fn prefill(
    // Run prefill
    let start_time = Instant::now();
-   let (_, decode_batch) = client.prefill(batch.clone()).await?;
+   let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
    // Get latency
    let latency = start_time.elapsed();

View File

@@ -182,6 +182,12 @@ message PrefillResponse {
    repeated Generation generations = 1;
    /// Next batch (cached)
    optional CachedBatch batch = 2;
+   /// Forward elapsed time in nanoseconds
+   uint64 forward_ns = 3;
+   /// Decode elapsed time in nanoseconds
+   uint64 decode_ns = 4;
+   /// Total elapsed time in nanoseconds
+   uint64 total_ns = 5;
}

message DecodeRequest {
@@ -194,6 +200,14 @@ message DecodeResponse {
    repeated Generation generations = 1;
    /// Next batch (cached)
    optional CachedBatch batch = 2;
+   /// Forward elapsed time in nanoseconds
+   uint64 forward_ns = 3;
+   /// Decode elapsed time in nanoseconds
+   uint64 decode_ns = 4;
+   /// Total elapsed time in nanoseconds
+   uint64 total_ns = 5;
+   /// Concatenate elapsed time in nanoseconds
+   optional uint64 concat_ns = 6;
}

message WarmupRequest {
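
Note (illustrative, not part of the diff): these fields surface on the generated Python messages that the server code further down fills in. A minimal sketch of constructing and reading one, assuming the generated stubs live in `text_generation_server.pb.generate_pb2` as in the server imports below, with made-up numbers:

from text_generation_server.pb import generate_pb2

# Hypothetical values; real responses are populated by the gRPC servicer below.
resp = generate_pb2.DecodeResponse(
    forward_ns=12_000_000,  # model forward pass (~12 ms)
    decode_ns=3_000_000,    # Python-side token post-processing (~3 ms)
    total_ns=16_000_000,    # wall time of the whole Decode call
    concat_ns=1_000_000,    # only set when several cached batches were concatenated
)
# Whatever total_ns does not attribute to concat/forward/decode is framework overhead.
print(resp.total_ns - resp.forward_ns - resp.decode_ns - resp.concat_ns)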

View File

@@ -4,6 +4,7 @@ use crate::pb::generate::v2::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
+ use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -157,10 +158,14 @@ impl Client {
    pub async fn prefill(
        &mut self,
        batch: Batch,
-   ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+   ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
-       Ok((response.generations, response.batch))
+       Ok((
+           response.generations,
+           response.batch,
+           PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+       ))
    }

    /// Generate one token for each request in the given cached batches
@@ -171,9 +176,52 @@ impl Client {
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
-   ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+   ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
-       Ok((response.generations, response.batch))
+       Ok((
+           response.generations,
+           response.batch,
+           DecodeTimings::new(
+               response.concat_ns,
+               response.forward_ns,
+               response.decode_ns,
+               response.total_ns,
+           ),
+       ))
+   }
+}
+
+ pub struct PrefillTimings {
+     pub forward: Duration,
+     pub decode: Duration,
+     pub total: Duration,
+ }
+
+ impl PrefillTimings {
+     fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+         Self {
+             forward: Duration::from_nanos(forward_ns),
+             decode: Duration::from_nanos(decode_ns),
+             total: Duration::from_nanos(total_ns),
+         }
+     }
+ }
+
+ pub struct DecodeTimings {
+     pub concat: Option<Duration>,
+     pub forward: Duration,
+     pub decode: Duration,
+     pub total: Duration,
+ }
+
+ impl DecodeTimings {
+     fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+         Self {
+             concat: concat_ns.map(|v| Duration::from_nanos(v)),
+             forward: Duration::from_nanos(forward_ns),
+             decode: Duration::from_nanos(decode_ns),
+             total: Duration::from_nanos(total_ns),
+         }
    }
}

View File

@@ -1,3 +1,4 @@
+ use crate::client::{DecodeTimings, PrefillTimings};
/// Multi shard Client
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
use crate::{ClientError, Result};
@@ -120,15 +121,28 @@ impl ShardedClient {
    pub async fn prefill(
        &mut self,
        batch: Batch,
-   ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+   ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.prefill(batch.clone())))
            .collect();
-       let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+       let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
            join_all(futures).await.into_iter().collect();
-       merge_generations(results?)
+       let mut results = results?;
+       let (mut generations, next_batch, mut timings) =
+           results.pop().ok_or(ClientError::EmptyResults)?;
+       // Merge generations from different model shards
+       for (mut shard_generations, _, shard_timings) in results.into_iter() {
+           generations.append(&mut shard_generations);
+           // Return the timings of the slowest shard
+           if shard_timings.total > timings.total {
+               timings = shard_timings;
+           }
+       }
+       Ok((generations, next_batch, timings))
    }

    /// Generate one token for each request in the given cached batches
@@ -139,26 +153,27 @@ impl ShardedClient {
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
-   ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+   ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.decode(batches.clone())))
            .collect();
-       let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+       let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
            join_all(futures).await.into_iter().collect();
-       merge_generations(results?)
-   }
-}
-
-/// Merge generations from the different model shards
-fn merge_generations(
-   mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
-) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
-   let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
-   for (mut shard_generations, _) in results.into_iter() {
+       let mut results = results?;
+       let (mut generations, next_batch, mut timings) =
+           results.pop().ok_or(ClientError::EmptyResults)?;
+       // Merge generations from different model shards
+       for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
+           // Return the timings of the slowest shard
+           if shard_timings.total > timings.total {
+               timings = shard_timings;
+           }
        }
-   Ok((generations, next_batch))
-}
+       Ok((generations, next_batch, timings))
    }
}

View File

@@ -379,15 +379,20 @@ async fn prefill(
    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
    match client.prefill(batch).await {
-       Ok((generations, next_batch)) => {
+       Ok((generations, next_batch, timings)) => {
            // Update health
            generation_health.store(true, Ordering::SeqCst);
+           let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);
            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;
+           metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
+           metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
+           metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
            next_batch
@@ -416,15 +421,23 @@ async fn decode(
    metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
    match client.decode(batches).await {
-       Ok((generations, next_batch)) => {
+       Ok((generations, next_batch, timings)) => {
            // Update health
            generation_health.store(true, Ordering::SeqCst);
+           let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);
            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;
+           if let Some(concat_duration) = timings.concat {
+               metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+           }
+           metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
+           metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
+           metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
            next_batch

View File

@@ -540,7 +540,7 @@ mod tests {
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
-       let max_total_tokens = 6;
+       let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,
@@ -600,7 +600,7 @@ mod tests {
        let max_stop_sequences = 3;
        let max_top_n_tokens = 4;
        let max_input_length = 5;
-       let max_total_tokens = 6;
+       let max_total_tokens = 106;
        let workers = 1;
        let validation = Validation::new(
            workers,

View File

@@ -103,7 +103,7 @@ def test_causal_lm_batch_type(default_bloom):
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
    sequence_length = len(default_bloom_batch.all_input_ids[0])
-   generations, next_batch = default_bloom.generate_token(default_bloom_batch)
+   generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)

    assert len(generations) == len(default_bloom_batch)
    assert isinstance(next_batch, CausalLMBatch)
@@ -153,10 +153,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
    next_batch = default_bloom_batch
    for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-       generations, next_batch = default_bloom.generate_token(next_batch)
+       generations, next_batch, _ = default_bloom.generate_token(next_batch)
        assert len(generations) == len(default_bloom_batch)

-   generations, next_batch = default_bloom.generate_token(next_batch)
+   generations, next_batch, _ = default_bloom.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -178,10 +178,10 @@ def test_causal_lm_generate_token_completion_multi(
    for i in range(
        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
    ):
-       generations, next_batch = default_bloom.generate_token(next_batch)
+       generations, next_batch, _ = default_bloom.generate_token(next_batch)
        assert len(generations) == len(default_multi_requests_bloom_batch)

-   generations, next_batch = default_bloom.generate_token(next_batch)
+   generations, next_batch, _ = default_bloom.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
@@ -201,10 +201,10 @@ def test_causal_lm_generate_token_completion_multi(
    for _ in range(
        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
    ):
-       generations, next_batch = default_bloom.generate_token(next_batch)
+       generations, next_batch, _ = default_bloom.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_bloom.generate_token(next_batch)
+   generations, next_batch, _ = default_bloom.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -224,11 +224,11 @@ def test_batch_concatenate(
    default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
    next_batch_0 = default_bloom_batch
-   _, next_batch_0 = default_bloom.generate_token(next_batch_0)
-   _, next_batch_0 = default_bloom.generate_token(next_batch_0)
+   _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
+   _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)

    next_batch_1 = default_multi_requests_bloom_batch
-   _, next_batch_1 = default_bloom.generate_token(next_batch_1)
+   _, next_batch_1, _ = default_bloom.generate_token(next_batch_1)

    # Clone past_key_values before concatenating to compare after,
    # because they are removed from the concatenated batches
@@ -288,10 +288,10 @@ def test_batch_concatenate(
    for _ in range(
        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
    ):
-       generations, next_batch = default_bloom.generate_token(next_batch)
+       generations, next_batch, _ = default_bloom.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_bloom.generate_token(next_batch)
+   generations, next_batch, _ = default_bloom.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 3
@@ -313,10 +313,10 @@ def test_batch_concatenate(
        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
        - 2
    ):
-       generations, next_batch = default_bloom.generate_token(next_batch)
+       generations, next_batch, _ = default_bloom.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_bloom.generate_token(next_batch)
+   generations, next_batch, _ = default_bloom.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
@@ -337,10 +337,10 @@ def test_batch_concatenate(
        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
        - 4
    ):
-       generations, next_batch = default_bloom.generate_token(next_batch)
+       generations, next_batch, _ = default_bloom.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_bloom.generate_token(next_batch)
+   generations, next_batch, _ = default_bloom.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1

View File

@@ -99,7 +99,9 @@ def test_causal_lm_batch_type(default_causal_lm):
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-   generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(
+       default_causal_lm_batch
+   )

    assert len(generations) == len(next_batch)
    assert isinstance(next_batch, CausalLMBatch)
@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion(
):
    next_batch = default_causal_lm_batch
    for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-       generations, next_batch = default_causal_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_causal_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -174,10 +176,10 @@ def test_causal_lm_generate_token_completion_multi(
    for i in range(
        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
    ):
-       generations, next_batch = default_causal_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_causal_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
@@ -200,10 +202,10 @@ def test_causal_lm_generate_token_completion_multi(
    for _ in range(
        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
    ):
-       generations, next_batch = default_causal_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_causal_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -222,11 +224,11 @@ def test_batch_concatenate(
    default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
    next_batch_0 = default_causal_lm_batch
-   _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
-   _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
+   _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
+   _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)

    next_batch_1 = default_multi_requests_causal_lm_batch
-   _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
+   _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)

    # Clone past_key_values before concatenating to compare after,
    # because they are removed from the concatenated batches
@@ -285,10 +287,10 @@ def test_batch_concatenate(
    for _ in range(
        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
    ):
-       generations, next_batch = default_causal_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_causal_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 3
@@ -311,10 +313,10 @@ def test_batch_concatenate(
        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 2
    ):
-       generations, next_batch = default_causal_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_causal_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
@@ -333,10 +335,10 @@ def test_batch_concatenate(
        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 4
    ):
-       generations, next_batch = default_causal_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_causal_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1

View File

@@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
    next_batch = batch
    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-       generations, next_batch = default_santacoder.generate_token(next_batch)
+       generations, next_batch, _ = default_santacoder.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_santacoder.generate_token(next_batch)
+   generations, next_batch, _ = default_santacoder.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion(
    next_batch = batch
    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-       generations, next_batch = default_santacoder.generate_token(next_batch)
+       generations, next_batch, _ = default_santacoder.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_santacoder.generate_token(next_batch)
+   generations, next_batch, _ = default_santacoder.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1

View File

@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-   generations, next_batch = default_seq2seq_lm.generate_token(
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(
        default_seq2seq_lm_batch
    )
@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
):
    next_batch = default_seq2seq_lm_batch
    for _ in range(6):
-       generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
    next_batch = default_multi_requests_seq2seq_lm_batch

    for i in range(4):
-       generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
    next_batch = next_batch.filter([next_batch.requests[0].id])

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert len(generations) == len(next_batch)

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1
@@ -228,11 +228,11 @@ def test_batch_concatenate(
    default_multi_requests_seq2seq_lm_batch,
):
    next_batch_0 = default_seq2seq_lm_batch
-   _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-   _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+   _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+   _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)

    next_batch_1 = default_multi_requests_seq2seq_lm_batch
-   _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+   _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)

    # Copy hidden state because it is removed from the concatenated branches
    next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@@ -324,10 +324,10 @@ def test_batch_concatenate(
    )

    for _ in range(3):
-       generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+       generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
        assert len(generations) == len(next_batch)

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 3
@@ -342,7 +342,7 @@ def test_batch_concatenate(
        [next_batch.requests[0].id, next_batch.requests[1].id]
    )

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generations) == 2
@@ -352,7 +352,7 @@ def test_batch_concatenate(
    next_batch = next_batch.filter([next_batch.requests[1].id])

-   generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+   generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generations) == 1

View File

@@ -1,6 +1,5 @@
-from text_generation_server.utils.tokens import batch_top_tokens
import torch
-import inspect
+import time

from dataclasses import dataclass
from opentelemetry import trace
@@ -8,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
    Batch,
    Tokens,
@@ -564,7 +564,8 @@ class CausalLM(Model):
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: CausalLMBatch
-   ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+   ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
+       start = time.time_ns()
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@@ -585,6 +586,8 @@ class CausalLM(Model):
            torch.log_softmax(logits[:, -1], -1),
        )

+       start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
@@ -731,7 +734,9 @@ class CausalLM(Model):
        # We finished all generations in the batch; there is no next batch
        if stopped:
-           return generations, None
+           forward_ns = start_decode - start
+           decode_ns = time.time_ns() - start_decode
+           return generations, None, (forward_ns, decode_ns)

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]
@@ -747,4 +752,6 @@ class CausalLM(Model):
        # Update past key values
        batch.past_key_values = past

-       return generations, batch
+       forward_ns = start_decode - start
+       decode_ns = time.time_ns() - start_decode
+       return generations, batch, (forward_ns, decode_ns)
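
The same measurement pattern repeats for FlashCausalLM, IdeficsCausalLM, and Seq2SeqLM in the files below: `time.time_ns()` is read once before the forward pass and once before the Python-side post-processing, and the pair is returned alongside the generations. A minimal standalone sketch of the idea (the function names here are illustrative, not from the repo):

import time

def timed_generate_step(forward_fn, postprocess_fn, batch):
    """Illustrative helper: split one generation step into forward vs decode time."""
    start = time.time_ns()
    logits = forward_fn(batch)            # model forward pass
    start_decode = time.time_ns()
    generations = postprocess_fn(logits)  # sampling, detokenization, stop checks
    forward_ns = start_decode - start
    decode_ns = time.time_ns() - start_decode
    return generations, (forward_ns, decode_ns)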

View File

@@ -1,6 +1,6 @@
import math
+import time
import itertools
-from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@@ -9,9 +9,10 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Union, Dict
+from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
@@ -689,7 +690,7 @@ class FlashCausalLM(Model):
                self.dtype,
                self.device,
            )
-           _, batch = self.generate_token(batch)
+           _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -799,7 +800,8 @@ class FlashCausalLM(Model):
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
-   ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+   ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+       start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -941,6 +943,8 @@ class FlashCausalLM(Model):
        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
+       accepted_ids = accepted_ids.tolist()
+       start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
@@ -977,7 +981,6 @@ class FlashCausalLM(Model):
            # Append next token to all tokens
            next_token_texts = []
            left = 0
-           before = stopping_criteria.current_tokens

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
@@ -1092,7 +1095,7 @@ class FlashCausalLM(Model):
            generations.append(generation)

            # Update values
-           batch.input_lengths[i] = input_length + n_accepted_ids.item()
+           batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
@@ -1102,10 +1105,14 @@ class FlashCausalLM(Model):
        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
-           return generations, None
+           forward_ns = start_decode - start
+           decode_ns = time.time_ns() - start_decode
+           return generations, None, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

-       return generations, batch
+       forward_ns = start_decode - start
+       decode_ns = time.time_ns() - start_decode
+       return generations, batch, (forward_ns, decode_ns)

View File

@@ -1,17 +1,11 @@
import torch
-import inspect
+import time
-import re
-from io import BytesIO
-import base64
-from PIL import Image
-import re

from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    AutoProcessor,
    AutoTokenizer,
-   AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    ProcessorMixin,
)
@@ -670,7 +664,8 @@ class IdeficsCausalLM(Model):
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: IdeficsCausalLMBatch
-   ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
+   ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
+       start = time.time_ns()
        # slice the attention mask to the correct shape
        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
        if batch.input_ids.size(1) == 1:
@@ -699,6 +694,8 @@ class IdeficsCausalLM(Model):
        # Hardcoded remove image tokens
        logits[:, 32000:32001] = torch.finfo(logits.dtype).min

+       start_decode = time.time_ns()

        # Results
        generations: List[Generation] = []
        stopped = True
@@ -827,7 +824,9 @@ class IdeficsCausalLM(Model):
        # We finished all generations in the batch; there is no next batch
        if stopped:
-           return generations, None
+           forward_ns = start_decode - start
+           decode_ns = time.time_ns() - start_decode
+           return generations, None, (forward_ns, decode_ns)

        # Slice unused values from prefill
        batch.input_ids = batch.input_ids[:, :1]
@@ -847,4 +846,6 @@ class IdeficsCausalLM(Model):
        batch.past_key_values = past
        batch.image_hidden_states = image_hidden_states

-       return generations, batch
+       forward_ns = start_decode - start
+       decode_ns = time.time_ns() - start_decode
+       return generations, batch, (forward_ns, decode_ns)

View File

@@ -65,7 +65,9 @@ class Model(ABC):
        raise NotImplementedError

    @abstractmethod
-   def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
+   def generate_token(
+       self, batch: B
+   ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:

View File

@@ -1,11 +1,12 @@
-from text_generation_server.utils.tokens import batch_top_tokens
import torch
+import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
    GeneratedText,
@@ -613,7 +614,8 @@ class Seq2SeqLM(Model):
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
-   ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
+   ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
+       start = time.time_ns()
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
@@ -644,6 +646,8 @@ class Seq2SeqLM(Model):
            torch.log_softmax(logits[:, -1], -1),
        )

+       start_decode = time.time_ns()

        # Finished requests
        generations: List[Generation] = []
        stopped = True
@@ -788,7 +792,9 @@ class Seq2SeqLM(Model):
        # We finished all generations in the batch; there is no next batch
        if stopped:
-           return generations, None
+           forward_ns = start_decode - start
+           decode_ns = time.time_ns() - start_decode
+           return generations, None, (forward_ns, decode_ns)

        # We don't need input_ids after the prefill forward
        batch.input_ids = None
@@ -799,4 +805,6 @@ class Seq2SeqLM(Model):
        batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1

-       return generations, batch
+       forward_ns = start_decode - start
+       decode_ns = time.time_ns() - start_decode
+       return generations, batch, (forward_ns, decode_ns)

View File

@@ -1,6 +1,7 @@
import asyncio
import os
import torch
+import time

from grpc import aio
from loguru import logger
@@ -76,6 +77,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        )

    async def Prefill(self, request, context):
+       start = time.time_ns()
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ): # Hack, i would rather use kwargs in the `from_pb` call
@@ -91,15 +93,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )

-       generations, next_batch = self.model.generate_token(batch)
+       generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
+           forward_ns=timings[0],
+           decode_ns=timings[1],
+           total_ns=time.time_ns() - start,
        )

    async def Decode(self, request, context):
+       start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")
@@ -114,16 +120,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            raise ValueError("All batches are empty")

        if len(batches) > 1:
+           start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
+           concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
+           concat_ns = None

-       generations, next_batch = self.model.generate_token(batch)
+       generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
+           concat_ns=concat_ns,
+           forward_ns=timings[0],
+           decode_ns=timings[1],
+           total_ns=time.time_ns() - start,
        )
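
With these changes the router receives a per-call breakdown from each shard. A small illustrative helper (not part of the commit) for turning the new nanosecond fields into a readable summary, e.g. when debugging a single shard over the raw gRPC API:

from typing import Optional

def format_timings(forward_ns: int, decode_ns: int, total_ns: int,
                   concat_ns: Optional[int] = None) -> str:
    """Format the timing fields of a Prefill/Decode response as milliseconds."""
    def _ms(ns: int) -> str:
        return f"{ns / 1e6:.2f}ms"
    parts = []
    if concat_ns is not None:
        parts.append(f"concat={_ms(concat_ns)}")
    parts += [f"forward={_ms(forward_ns)}", f"decode={_ms(decode_ns)}", f"total={_ms(total_ns)}"]
    # total_ns wraps the whole servicer call, so it also covers batch deserialization
    # and cache bookkeeping that forward/decode/concat do not account for.
    return " ".join(parts)

print(format_timings(forward_ns=12_000_000, decode_ns=3_000_000, total_ns=16_000_000))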

View File

@@ -92,7 +92,7 @@ class NextTokenChooser:
class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
-       self.regex = re.compile(f".*{stop_sequence}$")
+       self.regex = re.compile(f"{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
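
A side note on that last change (illustrative check, not from the repo): because `__call__` uses `re.findall`, anchoring the pattern only on `$` is enough to detect a stop sequence at the end of the output, so the leading `.*` was redundant and could be costly on long outputs:

import re

stop = re.escape("</s>")
old_pattern = re.compile(f".*{stop}$")
new_pattern = re.compile(f"{stop}$")

# Both patterns agree on whether the output ends with the stop sequence.
for output in ["hello world</s>", "hello </s> world", "no stop sequence", ""]:
    assert bool(old_pattern.findall(output)) == bool(new_pattern.findall(output))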