diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 67afa04e..ea7c9778 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -163,7 +163,7 @@ async fn prefill(
 
     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
 
     // Get latency
     let latency = start_time.elapsed();
diff --git a/proto/generate.proto b/proto/generate.proto
index 19ec059b..02f3b2e8 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -182,6 +182,12 @@ message PrefillResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
 }
 
 message DecodeRequest {
@@ -194,6 +200,14 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }
 
 message WarmupRequest {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 1560f19c..898e2b11 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -4,6 +4,7 @@ use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;
+use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
@@ -157,10 +158,14 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }
 
     /// Generate one token for each request in the given cached batches
@@ -171,9 +176,52 @@ impl Client {
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
+    }
+}
+
+pub struct PrefillTimings {
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl PrefillTimings {
+    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
+
+pub struct DecodeTimings {
+    pub concat: Option<Duration>,
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl DecodeTimings {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            concat: concat_ns.map(|v| Duration::from_nanos(v)),
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
     }
 }
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index b4bdcd42..6c5da3c7 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,3 +1,4 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
@@ -116,49 +117,63 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }
 
     /// Generate one token for each request in the given cached batches
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }
 }
-
-/// Merge generations from the different model shards
-fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
-) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
-    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
-
-    for (mut shard_generations, _) in results.into_iter() {
-        generations.append(&mut shard_generations);
-    }
-    Ok((generations, next_batch))
-}
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 2e199ce2..bf5920da 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -379,15 +379,20 @@ async fn prefill(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
 
     match client.prefill(batch).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
 
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch
@@ -416,15 +421,23 @@ async fn decode(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
     match client.decode(batches).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
 
+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+            }
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 1b47fc97..90dc3741 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -540,7 +540,7 @@
mod tests { let max_stop_sequence = 3; let max_top_n_tokens = 4; let max_input_length = 5; - let max_total_tokens = 6; + let max_total_tokens = 106; let workers = 1; let validation = Validation::new( workers, @@ -600,7 +600,7 @@ mod tests { let max_stop_sequences = 3; let max_top_n_tokens = 4; let max_input_length = 5; - let max_total_tokens = 6; + let max_total_tokens = 106; let workers = 1; let validation = Validation::new( workers, diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index d9a33795..66df708a 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -103,7 +103,7 @@ def test_causal_lm_batch_type(default_bloom): def test_causal_lm_generate_token(default_bloom, default_bloom_batch): sequence_length = len(default_bloom_batch.all_input_ids[0]) - generations, next_batch = default_bloom.generate_token(default_bloom_batch) + generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch) assert len(generations) == len(default_bloom_batch) assert isinstance(next_batch, CausalLMBatch) @@ -153,10 +153,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): next_batch = default_bloom_batch for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert len(generations) == len(default_bloom_batch) - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -178,10 +178,10 @@ def test_causal_lm_generate_token_completion_multi( for i in range( default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert len(generations) == len(default_multi_requests_bloom_batch) - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 @@ -201,10 +201,10 @@ def test_causal_lm_generate_token_completion_multi( for _ in range( stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -224,11 +224,11 @@ def test_batch_concatenate( default_bloom, default_bloom_batch, default_multi_requests_bloom_batch ): next_batch_0 = default_bloom_batch - _, next_batch_0 = default_bloom.generate_token(next_batch_0) - _, next_batch_0 = default_bloom.generate_token(next_batch_0) + _, next_batch_0, _ = default_bloom.generate_token(next_batch_0) + _, next_batch_0, _ = default_bloom.generate_token(next_batch_0) next_batch_1 = default_multi_requests_bloom_batch - _, next_batch_1 = default_bloom.generate_token(next_batch_1) + _, next_batch_1, _ = default_bloom.generate_token(next_batch_1) # Clone past_key_values before concatenating to compare after, # 
because they are removed from the concatenated batches @@ -288,10 +288,10 @@ def test_batch_concatenate( for _ in range( default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert next_batch is not None assert len(generations) == 3 @@ -313,10 +313,10 @@ def test_batch_concatenate( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 @@ -337,10 +337,10 @@ def test_batch_concatenate( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 4 ): - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch, _ = default_bloom.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 8b45e781..250fa354 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -99,7 +99,9 @@ def test_causal_lm_batch_type(default_causal_lm): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): sequence_length = len(default_causal_lm_batch.all_input_ids[0]) - generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch) + generations, next_batch, _ = default_causal_lm.generate_token( + default_causal_lm_batch + ) assert len(generations) == len(next_batch) assert isinstance(next_batch, CausalLMBatch) @@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion( ): next_batch = default_causal_lm_batch for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -174,10 +176,10 @@ def test_causal_lm_generate_token_completion_multi( for i in range( default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 @@ -200,10 +202,10 @@ def test_causal_lm_generate_token_completion_multi( for _ in range( 
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -222,11 +224,11 @@ def test_batch_concatenate( default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch ): next_batch_0 = default_causal_lm_batch - _, next_batch_0 = default_causal_lm.generate_token(next_batch_0) - _, next_batch_0 = default_causal_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0) next_batch_1 = default_multi_requests_causal_lm_batch - _, next_batch_1 = default_causal_lm.generate_token(next_batch_1) + _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1) # Clone past_key_values before concatenating to compare after, # because they are removed from the concatenated batches @@ -285,10 +287,10 @@ def test_batch_concatenate( for _ in range( default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 3 @@ -311,10 +313,10 @@ def test_batch_concatenate( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 @@ -333,10 +335,10 @@ def test_batch_concatenate( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 4 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch, _ = default_causal_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index fceec560..1e40e766 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat next_batch = batch for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_santacoder.generate_token(next_batch) + generations, next_batch, _ = default_santacoder.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_santacoder.generate_token(next_batch) + generations, next_batch, _ = 
default_santacoder.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion( next_batch = batch for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_santacoder.generate_token(next_batch) + generations, next_batch, _ = default_santacoder.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_santacoder.generate_token(next_batch) + generations, next_batch, _ = default_santacoder.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 373867c7..735ab5eb 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm): def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) - generations, next_batch = default_seq2seq_lm.generate_token( + generations, next_batch, _ = default_seq2seq_lm.generate_token( default_seq2seq_lm_batch ) @@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion( ): next_batch = default_seq2seq_lm_batch for _ in range(6): - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi( next_batch = default_multi_requests_seq2seq_lm_batch for i in range(4): - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 @@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi( next_batch = next_batch.filter([next_batch.requests[0].id]) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 @@ -228,11 +228,11 @@ def test_batch_concatenate( default_multi_requests_seq2seq_lm_batch, ): next_batch_0 = default_seq2seq_lm_batch - _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0) - _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0) + _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0) next_batch_1 = default_multi_requests_seq2seq_lm_batch - _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1) + _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1) # Copy hidden state because it is removed from the concatenated branches 
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state @@ -324,10 +324,10 @@ def test_batch_concatenate( ) for _ in range(3): - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 3 @@ -342,7 +342,7 @@ def test_batch_concatenate( [next_batch.requests[0].id, next_batch.requests[1].id] ) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None assert len(generations) == 2 @@ -352,7 +352,7 @@ def test_batch_concatenate( next_batch = next_batch.filter([next_batch.requests[1].id]) - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None assert len(generations) == 1 diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index b771264b..7b10256c 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,6 +1,5 @@ -from text_generation_server.utils.tokens import batch_top_tokens import torch -import inspect +import time from dataclasses import dataclass from opentelemetry import trace @@ -8,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model +from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, Tokens, @@ -564,7 +564,8 @@ class CausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( self, batch: CausalLMBatch - ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: + start = time.time_ns() # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] @@ -585,6 +586,8 @@ class CausalLM(Model): torch.log_softmax(logits[:, -1], -1), ) + start_decode = time.time_ns() + # Zipped iterator iterator = zip( batch.requests, @@ -731,7 +734,9 @@ class CausalLM(Model): # We finished all generations in the batch; there is no next batch if stopped: - return generations, None + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) # Slice unused values from prefill batch.input_ids = batch.input_ids[:, :1] @@ -747,4 +752,6 @@ class CausalLM(Model): # Update past key values batch.past_key_values = past - return generations, batch + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 14d30635..930082cd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,6 +1,6 @@ import math +import time import itertools -from text_generation_server.utils.tokens import 
batch_top_tokens import torch import torch.distributed @@ -9,9 +9,10 @@ import numpy as np from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase -from typing import Optional, Tuple, List, Type, Union, Dict +from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model +from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.speculate import get_speculate from text_generation_server.models.types import ( Batch, @@ -689,7 +690,7 @@ class FlashCausalLM(Model): self.dtype, self.device, ) - _, batch = self.generate_token(batch) + _, batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " @@ -799,7 +800,8 @@ class FlashCausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( self, batch: FlashCausalLMBatch - ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: + ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None @@ -941,6 +943,8 @@ class FlashCausalLM(Model): # GPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() next_token_ids = next_input_ids.tolist() + accepted_ids = accepted_ids.tolist() + start_decode = time.time_ns() # Zipped iterator iterator = zip( @@ -977,7 +981,6 @@ class FlashCausalLM(Model): # Append next token to all tokens next_token_texts = [] left = 0 - before = stopping_criteria.current_tokens current_stopped = False for j in range(index, index + n_accepted_ids): @@ -1092,7 +1095,7 @@ class FlashCausalLM(Model): generations.append(generation) # Update values - batch.input_lengths[i] = input_length + n_accepted_ids.item() + batch.input_lengths[i] = input_length + n_accepted_ids if batch.input_lengths[i] > batch.max_seqlen: batch.max_seqlen = batch.input_lengths[i] batch.prefix_offsets[i] = prefix_offset @@ -1102,10 +1105,14 @@ class FlashCausalLM(Model): if stopped: del batch # No need to return a batch if we know that all requests stopped - return generations, None + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None - return generations, batch + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 86389ad2..2f28688d 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -1,17 +1,11 @@ import torch -import inspect -import re -from io import BytesIO -import base64 -from PIL import Image -import re +import time from dataclasses import dataclass from opentelemetry import trace from transformers import ( AutoProcessor, AutoTokenizer, - AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin, ) @@ -670,7 +664,8 @@ class IdeficsCausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( self, batch: IdeficsCausalLMBatch - ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]: + ) -> Tuple[List[Generation], 
Optional[IdeficsCausalLMBatch], Tuple[int, int]]: + start = time.time_ns() # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] if batch.input_ids.size(1) == 1: @@ -699,6 +694,8 @@ class IdeficsCausalLM(Model): # Hardcoded remove image tokens logits[:, 32000:32001] = torch.finfo(logits.dtype).min + start_decode = time.time_ns() + # Results generations: List[Generation] = [] stopped = True @@ -827,7 +824,9 @@ class IdeficsCausalLM(Model): # We finished all generations in the batch; there is no next batch if stopped: - return generations, None + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) # Slice unused values from prefill batch.input_ids = batch.input_ids[:, :1] @@ -847,4 +846,6 @@ class IdeficsCausalLM(Model): batch.past_key_values = past batch.image_hidden_states = image_hidden_states - return generations, batch + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index dfb21dcb..cb358672 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -65,7 +65,9 @@ class Model(ABC): raise NotImplementedError @abstractmethod - def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]: + def generate_token( + self, batch: B + ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]: raise NotImplementedError def warmup(self, batch: B) -> Optional[int]: diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index a85ef58e..f2e4cec6 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,11 +1,12 @@ -from text_generation_server.utils.tokens import batch_top_tokens import torch +import time from dataclasses import dataclass from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( GeneratedText, @@ -613,7 +614,8 @@ class Seq2SeqLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( self, batch: Seq2SeqLMBatch - ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: + ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]: + start = time.time_ns() if batch.decoder_attention_mask is not None: # slice to the correct shape decoder_attention_mask = batch.decoder_attention_mask[ @@ -644,6 +646,8 @@ class Seq2SeqLM(Model): torch.log_softmax(logits[:, -1], -1), ) + start_decode = time.time_ns() + # Finished requests generations: List[Generation] = [] stopped = True @@ -788,7 +792,9 @@ class Seq2SeqLM(Model): # We finished all generations in the batch; there is no next batch if stopped: - return generations, None + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) # We don't need input_ids after the prefill forward batch.input_ids = None @@ -799,4 +805,6 @@ class Seq2SeqLM(Model): batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1 batch.padding_right_offset -= 1 - 
return generations, batch + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 75dba972..a65138c9 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -1,6 +1,7 @@ import asyncio import os import torch +import time from grpc import aio from loguru import logger @@ -76,6 +77,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): ) async def Prefill(self, request, context): + start = time.time_ns() if ( self.model.batch_type == IdeficsCausalLMBatch ): # Hack, i would rather use kwargs in the `from_pb` call @@ -91,15 +93,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - generations, next_batch = self.model.generate_token(batch) + generations, next_batch, timings = self.model.generate_token(batch) self.cache.set(next_batch) return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, + forward_ns=timings[0], + decode_ns=timings[1], + total_ns=time.time_ns() - start, ) async def Decode(self, request, context): + start = time.time_ns() if len(request.batches) == 0: raise ValueError("Must provide at least one batch") @@ -114,16 +120,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): raise ValueError("All batches are empty") if len(batches) > 1: + start_concat = time.time_ns() batch = self.model.batch_type.concatenate(batches) + concat_ns = time.time_ns() - start_concat else: batch = batches[0] + concat_ns = None - generations, next_batch = self.model.generate_token(batch) + generations, next_batch, timings = self.model.generate_token(batch) self.cache.set(next_batch) return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, + concat_ns=concat_ns, + forward_ns=timings[0], + decode_ns=timings[1], + total_ns=time.time_ns() - start, ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 0d208104..ff0556df 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -92,7 +92,7 @@ class NextTokenChooser: class StopSequenceCriteria: def __init__(self, stop_sequence: str): stop_sequence = re.escape(stop_sequence) - self.regex = re.compile(f".*{stop_sequence}$") + self.regex = re.compile(f"{stop_sequence}$") def __call__(self, output: str) -> bool: if self.regex.findall(output):
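Note on the new timing fields: they are raw nanosecond counters measured inside the Python server. `total_ns` brackets the whole `Prefill`/`Decode` RPC, `concat_ns` covers only the optional batch concatenation, and the `(forward_ns, decode_ns)` pair comes from `generate_token`, so `total_ns` is at least `concat_ns + forward_ns + decode_ns`, with the remainder spent on batch deserialization, cache updates and response building. A minimal sketch of how a consumer of `DecodeResponse` could break one call down by phase, assuming only the field names added above; the `PhaseBreakdown`/`breakdown` helpers are illustrative and not part of this patch:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PhaseBreakdown:
    """Millisecond view of the nanosecond timings carried by DecodeResponse."""

    concat_ms: float
    forward_ms: float
    decode_ms: float
    other_ms: float  # time inside the RPC not covered by the three timed phases


def breakdown(
    forward_ns: int, decode_ns: int, total_ns: int, concat_ns: Optional[int] = None
) -> PhaseBreakdown:
    concat = concat_ns or 0
    # total_ns brackets the whole RPC, so whatever is left is bookkeeping overhead.
    other = max(total_ns - concat - forward_ns - decode_ns, 0)
    return PhaseBreakdown(
        concat_ms=concat / 1e6,
        forward_ms=forward_ns / 1e6,
        decode_ms=decode_ns / 1e6,
        other_ms=other / 1e6,
    )


if __name__ == "__main__":
    # Hypothetical numbers, only to show the arithmetic.
    print(
        breakdown(
            forward_ns=12_000_000,
            decode_ns=3_000_000,
            total_ns=16_500_000,
            concat_ns=1_000_000,
        )
    )
```

On the router side these values surface as the `tgi_batch_forward_duration`, `tgi_batch_decode_duration`, `tgi_batch_concat_duration` and `tgi_batch_filter_duration` histograms (labelled by `method`), alongside the pre-existing `tgi_batch_inference_duration`; for sharded models the timings reported are those of the slowest shard.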