feat: add more latency metrics in forward (#1346)
parent 44b267ab22
commit 50b495f3d8
@@ -163,7 +163,7 @@ async fn prefill(
     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;

     // Get latency
     let latency = start_time.elapsed();
@@ -182,6 +182,12 @@ message PrefillResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
 }

 message DecodeRequest {
@@ -194,6 +200,14 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }

 message WarmupRequest {
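The two messages now carry the server-side timings as plain nanosecond counters; `concat_ns` is optional because it is only measured when several cached batches are concatenated before a decode step. As a rough illustration (not part of the diff), a Python consumer of these responses could convert the fields into seconds before recording them; `response` and `observe_seconds` are hypothetical stand-ins, and the real consumer in this change is the Rust client below:

    NS_PER_SECOND = 1e9

    def record_prefill_timings(response, observe_seconds):
        # Convert the nanosecond counters of a PrefillResponse-like message
        # into seconds and hand them to a metrics callback.
        observe_seconds("forward", response.forward_ns / NS_PER_SECOND)
        observe_seconds("decode", response.decode_ns / NS_PER_SECOND)
        observe_seconds("total", response.total_ns / NS_PER_SECOND)

    def record_decode_timings(response, observe_seconds):
        # concat_ns is optional: only present when batches were concatenated.
        if response.HasField("concat_ns"):
            observe_seconds("concat", response.concat_ns / NS_PER_SECOND)
        record_prefill_timings(response, observe_seconds)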
@@ -4,6 +4,7 @@ use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;
+use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

@@ -157,10 +158,14 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }

     /// Generate one token for each request in the given cached batches
@@ -171,9 +176,52 @@ impl Client {
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
     }
 }
+
+pub struct PrefillTimings {
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl PrefillTimings {
+    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
+
+pub struct DecodeTimings {
+    pub concat: Option<Duration>,
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl DecodeTimings {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            concat: concat_ns.map(|v| Duration::from_nanos(v)),
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
@@ -1,3 +1,4 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
@@ -116,49 +117,63 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }

     /// Generate one token for each request in the given cached batches
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
            join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }
 }

-/// Merge generations from the different model shards
-fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
-) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
-    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
-
-    for (mut shard_generations, _) in results.into_iter() {
-        generations.append(&mut shard_generations);
-    }
-    Ok((generations, next_batch))
-}
@@ -379,15 +379,20 @@ async fn prefill(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

     match client.prefill(batch).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch
@@ -416,15 +421,23 @@ async fn decode(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

     match client.decode(batches).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+            }
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
@@ -540,7 +540,7 @@ mod tests {
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
        let max_input_length = 5;
-        let max_total_tokens = 6;
+        let max_total_tokens = 106;
         let workers = 1;
         let validation = Validation::new(
             workers,
@@ -600,7 +600,7 @@ mod tests {
         let max_stop_sequences = 3;
         let max_top_n_tokens = 4;
         let max_input_length = 5;
-        let max_total_tokens = 6;
+        let max_total_tokens = 106;
         let workers = 1;
         let validation = Validation::new(
             workers,
@@ -103,7 +103,7 @@ def test_causal_lm_batch_type(default_bloom):

 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generations, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)

     assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)
@@ -153,10 +153,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(default_bloom_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -178,10 +178,10 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(default_multi_requests_bloom_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -201,10 +201,10 @@ def test_causal_lm_generate_token_completion_multi(
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -224,11 +224,11 @@ def test_batch_concatenate(
     default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
 ):
     next_batch_0 = default_bloom_batch
-    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
-    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_bloom_batch
-    _, next_batch_1 = default_bloom.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_bloom.generate_token(next_batch_1)

     # Clone past_key_values before concatenating to compare after,
     # because they are removed from the concatenated batches
@@ -288,10 +288,10 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3
@@ -313,10 +313,10 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -337,10 +337,10 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -99,7 +99,9 @@ def test_causal_lm_batch_type(default_causal_lm):

 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(
+        default_causal_lm_batch
+    )

     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion(
 ):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -174,10 +176,10 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -200,10 +202,10 @@ def test_causal_lm_generate_token_completion_multi(
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -222,11 +224,11 @@ def test_batch_concatenate(
     default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
 ):
     next_batch_0 = default_causal_lm_batch
-    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_causal_lm_batch
-    _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)

     # Clone past_key_values before concatenating to compare after,
     # because they are removed from the concatenated batches
@@ -285,10 +287,10 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3
@@ -311,10 +313,10 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -333,10 +335,10 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_santacoder.generate_token(next_batch)
+        generations, next_batch, _ = default_santacoder.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch, _ = default_santacoder.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion(
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_santacoder.generate_token(next_batch)
+        generations, next_batch, _ = default_santacoder.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch, _ = default_santacoder.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):

 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generations, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )

@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch

     for i in range(4):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(

     next_batch = next_batch.filter([next_batch.requests[0].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -228,11 +228,11 @@ def test_batch_concatenate(
     default_multi_requests_seq2seq_lm_batch,
 ):
     next_batch_0 = default_seq2seq_lm_batch
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_seq2seq_lm_batch
-    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)

     # Copy hidden state because it is removed from the concatenated branches
     next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@@ -324,10 +324,10 @@ def test_batch_concatenate(
     )

     for _ in range(3):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3
@@ -342,7 +342,7 @@ def test_batch_concatenate(
         [next_batch.requests[0].id, next_batch.requests[1].id]
     )

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2
@@ -352,7 +352,7 @@ def test_batch_concatenate(

     next_batch = next_batch.filter([next_batch.requests[1].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -1,6 +1,5 @@
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import inspect
+import time

 from dataclasses import dataclass
 from opentelemetry import trace
@@ -8,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -564,7 +564,8 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

@@ -585,6 +586,8 @@ class CausalLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

+        start_decode = time.time_ns()
+
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -731,7 +734,9 @@ class CausalLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
@@ -747,4 +752,6 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -1,6 +1,6 @@
 import math
+import time
 import itertools
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed

@@ -9,9 +9,10 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Union, Dict
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
@@ -689,7 +690,7 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
-            _, batch = self.generate_token(batch)
+            _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -799,7 +800,8 @@ class FlashCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

@@ -941,6 +943,8 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
+
+        start_decode = time.time_ns()

         # Zipped iterator
         iterator = zip(
@@ -977,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
-            before = stopping_criteria.current_tokens

             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1092,7 +1095,7 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
@@ -1102,10 +1105,14 @@ class FlashCausalLM(Model):
         if stopped:
             del batch
             # No need to return a batch if we know that all requests stopped
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -1,17 +1,11 @@
 import torch
-import inspect
-import re
-from io import BytesIO
-import base64
-from PIL import Image
+import re
+import time

 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     AutoProcessor,
     AutoTokenizer,
     AutoModelForCausalLM,
     PreTrainedTokenizerBase,
     ProcessorMixin,
 )
@@ -670,7 +664,8 @@ class IdeficsCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: IdeficsCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:
@@ -699,6 +694,8 @@ class IdeficsCausalLM(Model):
         # Hardcoded remove image tokens
         logits[:, 32000:32001] = torch.finfo(logits.dtype).min

+        start_decode = time.time_ns()
+
         # Results
         generations: List[Generation] = []
         stopped = True
@@ -827,7 +824,9 @@ class IdeficsCausalLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]
@@ -847,4 +846,6 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -65,7 +65,9 @@ class Model(ABC):
         raise NotImplementedError

     @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
+    def generate_token(
+        self, batch: B
+    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError

     def warmup(self, batch: B) -> Optional[int]:
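The abstract contract now returns a third element, a (forward_ns, decode_ns) pair measured with time.time_ns(). A minimal, self-contained sketch of that pattern (illustrative only; timed_generate_token and the stand-in callables are not part of the PR):

    import time

    def timed_generate_token(forward_fn, decode_fn, batch):
        # Run the model forward pass, then the token decoding loop,
        # and report both phases in nanoseconds alongside the results.
        start = time.time_ns()
        logits = forward_fn(batch)
        start_decode = time.time_ns()
        generations, next_batch = decode_fn(batch, logits)
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, next_batch, (forward_ns, decode_ns)

    # Usage with trivial stand-ins for the real model:
    generations, next_batch, (forward_ns, decode_ns) = timed_generate_token(
        lambda batch: [0.0] * len(batch),           # fake forward pass
        lambda batch, logits: (list(batch), None),  # fake decoding step
        batch=[1, 2, 3],
    )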
@@ -1,11 +1,12 @@
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
+import time

 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     GeneratedText,
@@ -613,7 +614,8 @@ class Seq2SeqLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: Seq2SeqLMBatch
-    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
             decoder_attention_mask = batch.decoder_attention_mask[
@@ -644,6 +646,8 @@ class Seq2SeqLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

+        start_decode = time.time_ns()
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -788,7 +792,9 @@ class Seq2SeqLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # We don't need input_ids after the prefill forward
         batch.input_ids = None
@@ -799,4 +805,6 @@ class Seq2SeqLM(Model):
         batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
         batch.padding_right_offset -= 1

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -1,6 +1,7 @@
 import asyncio
 import os
 import torch
+import time

 from grpc import aio
 from loguru import logger
@@ -76,6 +77,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

     async def Prefill(self, request, context):
+        start = time.time_ns()
         if (
             self.model.batch_type == IdeficsCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call
@@ -91,15 +93,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch, self.model.tokenizer, self.model.dtype, self.model.device
             )

-        generations, next_batch = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
         )

     async def Decode(self, request, context):
+        start = time.time_ns()
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")

@@ -114,16 +120,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
+            start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)
+            concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]
+            concat_ns = None

-        generations, next_batch = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
+            concat_ns=concat_ns,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
         )
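Because start is taken when the request handler is entered, total_ns also covers batch deserialization, the optional concatenation and the cache update, so the gap between total_ns and the reported phases approximates the per-call overhead outside the model. A small, hypothetical helper (not from the PR) that makes that relationship explicit:

    def overhead_ns(total_ns, forward_ns, decode_ns, concat_ns=None):
        # Everything in total_ns that is not attributed to a reported phase:
        # deserialization, filtering, cache updates, Python/gRPC overhead, ...
        accounted = forward_ns + decode_ns + (concat_ns or 0)
        return max(total_ns - accounted, 0)

    # With made-up numbers (nanoseconds): prints 1000000, i.e. about 1 ms.
    print(overhead_ns(total_ns=12_000_000, forward_ns=9_000_000, decode_ns=2_000_000))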
@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")

     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):
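The stop-sequence pattern stays anchored with $ and is still applied through findall, so dropping the leading .* does not change when the criterion fires; it only removes a redundant prefix match. A small, self-contained check (not from the PR) of that equivalence on a few sample strings:

    import re

    stop_sequence = re.escape("</s>")
    old_regex = re.compile(f".*{stop_sequence}$")
    new_regex = re.compile(f"{stop_sequence}$")

    for output in ["hello</s>", "</s>", "hello", "hello</s> world"]:
        # Both patterns agree on whether the text ends with the stop sequence.
        assert bool(old_regex.findall(output)) == bool(new_regex.findall(output))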