feat: add more latency metrics in forward (#1346)

OlivierDehaene 2023-12-14 15:59:38 +01:00 committed by GitHub
parent 44b267ab22
commit 50b495f3d8
17 changed files with 240 additions and 110 deletions

View File

@ -163,7 +163,7 @@ async fn prefill(
// Run prefill
let start_time = Instant::now();
let (_, decode_batch) = client.prefill(batch.clone()).await?;
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
// Get latency
let latency = start_time.elapsed();

View File

@ -182,6 +182,12 @@ message PrefillResponse {
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
}
message DecodeRequest {
@ -194,6 +200,14 @@ message DecodeResponse {
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
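
The prefill and decode responses now expose a server-side latency breakdown in nanoseconds, plus an optional concat time on decode that is only set when batches were merged. A minimal Python sketch of turning a PrefillResponse into seconds for logging; the response object comes from the generated gRPC stubs, and the helper name is made up for illustration:

def log_prefill_timings(response) -> None:
    # Field names match the proto above; values are nanoseconds.
    forward_s = response.forward_ns / 1e9
    decode_s = response.decode_ns / 1e9
    total_s = response.total_ns / 1e9
    print(f"prefill: forward={forward_s:.4f}s decode={decode_s:.4f}s total={total_s:.4f}s")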

View File

@ -4,6 +4,7 @@ use crate::pb::generate::v2::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -157,10 +158,14 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
@ -171,9 +176,52 @@ impl Client {
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(|v| Duration::from_nanos(v)),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -1,3 +1,4 @@
use crate::client::{DecodeTimings, PrefillTimings};
/// Multi shard Client
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
use crate::{ClientError, Result};
@ -116,49 +117,63 @@ impl ShardedClient {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
/// Merge generations from the different model shards
fn merge_generations(
mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
for (mut shard_generations, _) in results.into_iter() {
generations.append(&mut shard_generations);
}
Ok((generations, next_batch))
}
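
The sharded client now merges the per-shard results inline and keeps the timings of the slowest shard, which is the shard that bounds the step's latency. An illustrative Python sketch of the same selection logic (the real implementation is the Rust code above; ShardResult and the seconds-based total value are stand-ins):

from typing import List, Optional, Tuple

# (generations, next_batch, total_seconds) reported by one shard
ShardResult = Tuple[list, Optional[object], float]

def merge_shards(results: List[ShardResult]) -> ShardResult:
    if not results:
        raise ValueError("empty results")
    generations, next_batch, slowest_total = results.pop()
    for shard_generations, _, shard_total in results:
        generations.extend(shard_generations)
        # keep the timings of the slowest shard
        slowest_total = max(slowest_total, shard_total)
    return generations, next_batch, slowest_total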

View File

@ -379,15 +379,20 @@ async fn prefill(
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
@ -416,15 +421,23 @@ async fn decode(
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch

View File

@ -540,7 +540,7 @@ mod tests {
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let max_total_tokens = 106;
let workers = 1;
let validation = Validation::new(
workers,
@ -600,7 +600,7 @@ mod tests {
let max_stop_sequences = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let max_total_tokens = 106;
let workers = 1;
let validation = Validation::new(
workers,

View File

@ -103,7 +103,7 @@ def test_causal_lm_batch_type(default_bloom):
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generations, next_batch = default_bloom.generate_token(default_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)
assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch)
@ -153,10 +153,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_bloom_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -178,10 +178,10 @@ def test_causal_lm_generate_token_completion_multi(
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_multi_requests_bloom_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -201,10 +201,10 @@ def test_causal_lm_generate_token_completion_multi(
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -224,11 +224,11 @@ def test_batch_concatenate(
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
next_batch_0 = default_bloom_batch
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
_, next_batch_1, _ = default_bloom.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
@ -288,10 +288,10 @@ def test_batch_concatenate(
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -313,10 +313,10 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -337,10 +337,10 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -99,7 +99,9 @@ def test_causal_lm_batch_type(default_causal_lm):
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
generations, next_batch, _ = default_causal_lm.generate_token(
default_causal_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion(
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -174,10 +176,10 @@ def test_causal_lm_generate_token_completion_multi(
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -200,10 +202,10 @@ def test_causal_lm_generate_token_completion_multi(
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -222,11 +224,11 @@ def test_batch_concatenate(
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
next_batch_0 = default_causal_lm_batch
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
_, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
@ -285,10 +287,10 @@ def test_batch_concatenate(
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -311,10 +313,10 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -333,10 +335,10 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion(
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generations, next_batch = default_seq2seq_lm.generate_token(
generations, next_batch, _ = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -228,11 +228,11 @@ def test_batch_concatenate(
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
# Copy hidden state because it is removed from the concatenated branches
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@ -324,10 +324,10 @@ def test_batch_concatenate(
)
for _ in range(3):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -342,7 +342,7 @@ def test_batch_concatenate(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -352,7 +352,7 @@ def test_batch_concatenate(
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -1,6 +1,5 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import inspect
import time
from dataclasses import dataclass
from opentelemetry import trace
@ -8,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
Tokens,
@ -564,7 +564,8 @@ class CausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@ -585,6 +586,8 @@ class CausalLM(Model):
torch.log_softmax(logits[:, -1], -1),
)
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
@ -731,7 +734,9 @@ class CausalLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
@ -747,4 +752,6 @@ class CausalLM(Model):
# Update past key values
batch.past_key_values = past
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
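
generate_token now splits its wall time into a forward phase and a decode/post-processing phase with time.time_ns(), and returns the pair alongside the generations. The measurement pattern, reduced to a minimal sketch with stand-in forward and post-processing callables:

import time

def generate_token_timed(forward, postprocess, batch):
    start = time.time_ns()
    logits = forward(batch)                # model forward pass
    start_decode = time.time_ns()
    generations = postprocess(logits)      # sampling, detokenization, stopping criteria
    forward_ns = start_decode - start
    decode_ns = time.time_ns() - start_decode
    return generations, (forward_ns, decode_ns)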

View File

@ -1,6 +1,6 @@
import math
import time
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@ -9,9 +9,10 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
@ -689,7 +690,7 @@ class FlashCausalLM(Model):
self.dtype,
self.device,
)
_, batch = self.generate_token(batch)
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@ -799,7 +800,8 @@ class FlashCausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
@ -941,6 +943,8 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
@ -977,7 +981,6 @@ class FlashCausalLM(Model):
# Append next token to all tokens
next_token_texts = []
left = 0
before = stopping_criteria.current_tokens
current_stopped = False
for j in range(index, index + n_accepted_ids):
@ -1092,7 +1095,7 @@ class FlashCausalLM(Model):
generations.append(generation)
# Update values
batch.input_lengths[i] = input_length + n_accepted_ids.item()
batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset
@ -1102,10 +1105,14 @@ class FlashCausalLM(Model):
if stopped:
del batch
# No need to return a batch if we know that all requests stopped
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
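
Besides the timings, this file batches the GPU-to-CPU transfer of accepted_ids with a single .tolist() call, which is why the later per-request update drops .item(): the elements are already plain Python ints. A small sketch of that detail, assuming torch is installed:

import torch

accepted_ids = torch.tensor([2, 1, 3])
# One conversion for the whole tensor (one device-to-host sync when it
# lives on the GPU); elements come back as plain Python ints.
accepted_ids = accepted_ids.tolist()
n_accepted_ids = accepted_ids[0]
assert isinstance(n_accepted_ids, int)    # no per-element .item() needed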

View File

@ -1,17 +1,11 @@
import torch
import inspect
import re
from io import BytesIO
import base64
from PIL import Image
import re
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
AutoProcessor,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
ProcessorMixin,
)
@ -670,7 +664,8 @@ class IdeficsCausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: IdeficsCausalLMBatch
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.input_ids.size(1) == 1:
@ -699,6 +694,8 @@ class IdeficsCausalLM(Model):
# Hardcoded remove image tokens
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
@ -827,7 +824,9 @@ class IdeficsCausalLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
@ -847,4 +846,6 @@ class IdeficsCausalLM(Model):
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -65,7 +65,9 @@ class Model(ABC):
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
def generate_token(
self, batch: B
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]:
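
With the abstract signature updated, every Model subclass returns a third element from generate_token: a (forward_ns, decode_ns) pair. A short sketch of the caller-side contract (model and batch are assumed to exist; the tests in this commit simply discard the pair with `_`):

def run_step(model, batch):
    generations, next_batch, (forward_ns, decode_ns) = model.generate_token(batch)
    print(f"forward {forward_ns / 1e6:.2f} ms, decode {decode_ns / 1e6:.2f} ms")
    return generations, next_batch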

View File

@ -1,11 +1,12 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
GeneratedText,
@ -613,7 +614,8 @@ class Seq2SeqLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
start = time.time_ns()
if batch.decoder_attention_mask is not None:
# slice to the correct shape
decoder_attention_mask = batch.decoder_attention_mask[
@ -644,6 +646,8 @@ class Seq2SeqLM(Model):
torch.log_softmax(logits[:, -1], -1),
)
start_decode = time.time_ns()
# Finished requests
generations: List[Generation] = []
stopped = True
@ -788,7 +792,9 @@ class Seq2SeqLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# We don't need input_ids after the prefill forward
batch.input_ids = None
@ -799,4 +805,6 @@ class Seq2SeqLM(Model):
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -1,6 +1,7 @@
import asyncio
import os
import torch
import time
from grpc import aio
from loguru import logger
@ -76,6 +77,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
async def Prefill(self, request, context):
start = time.time_ns()
if (
self.model.batch_type == IdeficsCausalLMBatch
): # Hack, i would rather use kwargs in the `from_pb` call
@ -91,15 +93,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token(batch)
generations, next_batch, timings = self.model.generate_token(batch)
self.cache.set(next_batch)
return generate_pb2.PrefillResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
)
async def Decode(self, request, context):
start = time.time_ns()
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
@ -114,16 +120,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
raise ValueError("All batches are empty")
if len(batches) > 1:
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate(batches)
concat_ns = time.time_ns() - start_concat
else:
batch = batches[0]
concat_ns = None
generations, next_batch = self.model.generate_token(batch)
generations, next_batch, timings = self.model.generate_token(batch)
self.cache.set(next_batch)
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
concat_ns=concat_ns,
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
)
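
total_ns is measured around the whole handler, so it also covers batch deserialization (from_pb), optional concatenation, cache updates and response construction; the gap between total_ns and the summed component timings is that per-step overhead. A sketch of deriving it from a DecodeResponse on the client side (the response object is assumed; concat_ns is a proto3 optional field, hence HasField):

def decode_overhead_ns(response) -> int:
    # Nanoseconds spent outside forward/decode (and concat, when present).
    components = response.forward_ns + response.decode_ns
    if response.HasField("concat_ns"):
        components += response.concat_ns
    return response.total_ns - components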

View File

@ -92,7 +92,7 @@ class NextTokenChooser:
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
stop_sequence = re.escape(stop_sequence)
self.regex = re.compile(f".*{stop_sequence}$")
self.regex = re.compile(f"{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
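
Unrelated to the timing fields, StopSequenceCriteria drops the leading `.*` from its pattern: since findall already scans every start position, anchoring on `$` alone still fires exactly when the output ends with the stop sequence, and it avoids re-matching the whole prefix on long outputs. A quick check of the equivalence (stop sequence chosen arbitrarily for illustration):

import re

stop_sequence = re.escape("###")
old = re.compile(f".*{stop_sequence}$")
new = re.compile(f"{stop_sequence}$")

for output in ["The answer is 42 ###", "### not at the end", "no stop sequence"]:
    assert bool(old.findall(output)) == bool(new.findall(output))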