feat: add more latency metrics in forward (#1346)

parent 44b267ab22
commit 50b495f3d8
@@ -163,7 +163,7 @@ async fn prefill(
     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
     // Get latency
     let latency = start_time.elapsed();
@@ -182,6 +182,12 @@ message PrefillResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
 }

 message DecodeRequest {

@@ -194,6 +200,14 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
    optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }

 message WarmupRequest {
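Note: with the Rust gRPC codegen used by the client below, the new response fields surface as plain struct members. A simplified sketch of the expected shape (not generated output, just an illustration of how the optional `concat_ns` differs from the required fields):

    // Approximate shape of the generated response types after this change (sketch only).
    pub struct PrefillResponse {
        pub generations: Vec<Generation>,
        pub batch: Option<CachedBatch>,
        pub forward_ns: u64,
        pub decode_ns: u64,
        pub total_ns: u64,
    }

    pub struct DecodeResponse {
        pub generations: Vec<Generation>,
        pub batch: Option<CachedBatch>,
        pub forward_ns: u64,
        pub decode_ns: u64,
        pub total_ns: u64,
        pub concat_ns: Option<u64>, // the proto `optional` field maps to an Option
    }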
@@ -4,6 +4,7 @@ use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;
+use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

@@ -157,10 +158,14 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }

     /// Generate one token for each request in the given cached batches

@@ -171,9 +176,52 @@ impl Client {
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
+    }
+}
+
+pub struct PrefillTimings {
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl PrefillTimings {
+    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
+
+pub struct DecodeTimings {
+    pub concat: Option<Duration>,
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl DecodeTimings {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            concat: concat_ns.map(|v| Duration::from_nanos(v)),
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
     }
 }
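Note: `PrefillTimings` and `DecodeTimings` simply wrap the nanosecond fields of the gRPC responses in `std::time::Duration`. A minimal sketch of how a caller of `Client::prefill` might surface them (illustrative only; the `tracing` fields and variable names are assumptions, not part of this commit):

    // Sketch: log the per-phase latency returned alongside the generations.
    let (generations, cached_batch, timings) = client.prefill(batch.clone()).await?;
    tracing::info!(
        forward_ms = timings.forward.as_millis() as u64,
        decode_ms = timings.decode.as_millis() as u64,
        total_ms = timings.total.as_millis() as u64,
        n_generations = generations.len(),
        "prefill timings"
    );
    let _ = cached_batch; // the next cached batch, unused in this sketch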
@@ -1,3 +1,4 @@
+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};

@@ -116,49 +117,63 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
-    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }

     /// Generate one token for each request in the given cached batches
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
-    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }
 }
-
-/// Merge generations from the different model shards
-fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
-) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
-    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
-
-    for (mut shard_generations, _) in results.into_iter() {
-        generations.append(&mut shard_generations);
-    }
-    Ok((generations, next_batch))
-}
@@ -379,15 +379,20 @@ async fn prefill(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

     match client.prefill(batch).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch

@@ -416,15 +421,23 @@ async fn decode(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

     match client.decode(batches).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+            }
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
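Note: together these hunks record four new histograms (`tgi_batch_concat_duration`, `tgi_batch_forward_duration`, `tgi_batch_decode_duration`, `tgi_batch_filter_duration`) alongside the existing `tgi_batch_inference_duration`, all labeled by `method`. A small illustrative helper, not part of this change, showing one way the per-phase timings could be summarized using only the `DecodeTimings` fields introduced above:

    // Sketch: fraction of a decode round trip spent in the model forward pass.
    fn forward_share(timings: &DecodeTimings) -> f64 {
        let total = timings.total.as_secs_f64();
        if total > 0.0 {
            timings.forward.as_secs_f64() / total
        } else {
            0.0
        }
    }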
@@ -540,7 +540,7 @@ mod tests {
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
         let max_input_length = 5;
-        let max_total_tokens = 6;
+        let max_total_tokens = 106;
         let workers = 1;
         let validation = Validation::new(
             workers,

@@ -600,7 +600,7 @@ mod tests {
         let max_stop_sequences = 3;
         let max_top_n_tokens = 4;
         let max_input_length = 5;
-        let max_total_tokens = 6;
+        let max_total_tokens = 106;
         let workers = 1;
         let validation = Validation::new(
             workers,
@@ -103,7 +103,7 @@ def test_causal_lm_batch_type(default_bloom):

 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generations, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)

     assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)

@@ -153,10 +153,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(default_bloom_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -178,10 +178,10 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(default_multi_requests_bloom_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -201,10 +201,10 @@ def test_causal_lm_generate_token_completion_multi(
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -224,11 +224,11 @@ def test_batch_concatenate(
     default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
 ):
     next_batch_0 = default_bloom_batch
-    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
-    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_bloom_batch
-    _, next_batch_1 = default_bloom.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_bloom.generate_token(next_batch_1)

     # Clone past_key_values before concatenating to compare after,
     # because they are removed from the concatenated batches

@@ -288,10 +288,10 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3

@@ -313,10 +313,10 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -337,10 +337,10 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -99,7 +99,9 @@ def test_causal_lm_batch_type(default_causal_lm):

 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(
+        default_causal_lm_batch
+    )

     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)

@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion(
 ):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -174,10 +176,10 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -200,10 +202,10 @@ def test_causal_lm_generate_token_completion_multi(
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -222,11 +224,11 @@ def test_batch_concatenate(
     default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
 ):
     next_batch_0 = default_causal_lm_batch
-    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_causal_lm_batch
-    _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)

     # Clone past_key_values before concatenating to compare after,
     # because they are removed from the concatenated batches

@@ -285,10 +287,10 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3

@@ -311,10 +313,10 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -333,10 +335,10 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_santacoder.generate_token(next_batch)
+        generations, next_batch, _ = default_santacoder.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch, _ = default_santacoder.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion(
     next_batch = batch

     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_santacoder.generate_token(next_batch)
+        generations, next_batch, _ = default_santacoder.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch, _ = default_santacoder.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):

 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generations, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )

@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch

     for i in range(4):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(

     next_batch = next_batch.filter([next_batch.requests[0].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -228,11 +228,11 @@ def test_batch_concatenate(
     default_multi_requests_seq2seq_lm_batch,
 ):
     next_batch_0 = default_seq2seq_lm_batch
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_seq2seq_lm_batch
-    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)

     # Copy hidden state because it is removed from the concatenated branches
     next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state

@@ -324,10 +324,10 @@ def test_batch_concatenate(
     )

     for _ in range(3):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3

@@ -342,7 +342,7 @@ def test_batch_concatenate(
         [next_batch.requests[0].id, next_batch.requests[1].id]
     )

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -352,7 +352,7 @@ def test_batch_concatenate(

     next_batch = next_batch.filter([next_batch.requests[1].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
@@ -1,6 +1,5 @@
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
-import inspect
+import time

 from dataclasses import dataclass
 from opentelemetry import trace

@@ -8,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
     Tokens,

@@ -564,7 +564,8 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

@@ -585,6 +586,8 @@ class CausalLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

+        start_decode = time.time_ns()
+
         # Zipped iterator
         iterator = zip(
             batch.requests,

@@ -731,7 +734,9 @@ class CausalLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]

@@ -747,4 +752,6 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -1,6 +1,6 @@
 import math
+import time
 import itertools
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed

@@ -9,9 +9,10 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Union, Dict
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,

@@ -689,7 +690,7 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
-            _, batch = self.generate_token(batch)
+            _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "

@@ -799,7 +800,8 @@ class FlashCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

@@ -941,6 +943,8 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
+        accepted_ids = accepted_ids.tolist()
+        start_decode = time.time_ns()

         # Zipped iterator
         iterator = zip(

@@ -977,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
-            before = stopping_criteria.current_tokens

             current_stopped = False
             for j in range(index, index + n_accepted_ids):

@@ -1092,7 +1095,7 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset

@@ -1102,10 +1105,14 @@ class FlashCausalLM(Model):
         if stopped:
             del batch
             # No need to return a batch if we know that all requests stopped
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -1,17 +1,11 @@
 import torch
-import inspect
-import re
-from io import BytesIO
-import base64
-from PIL import Image
-import re
+import time

 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     AutoProcessor,
     AutoTokenizer,
-    AutoModelForCausalLM,
     PreTrainedTokenizerBase,
     ProcessorMixin,
 )

@@ -670,7 +664,8 @@ class IdeficsCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: IdeficsCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:

@@ -699,6 +694,8 @@ class IdeficsCausalLM(Model):
         # Hardcoded remove image tokens
         logits[:, 32000:32001] = torch.finfo(logits.dtype).min

+        start_decode = time.time_ns()
+
         # Results
         generations: List[Generation] = []
         stopped = True

@@ -827,7 +824,9 @@ class IdeficsCausalLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]

@@ -847,4 +846,6 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -65,7 +65,9 @@ class Model(ABC):
         raise NotImplementedError

     @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
+    def generate_token(
+        self, batch: B
+    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError

     def warmup(self, batch: B) -> Optional[int]:
@@ -1,11 +1,12 @@
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
+import time

 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     GeneratedText,

@@ -613,7 +614,8 @@ class Seq2SeqLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: Seq2SeqLMBatch
-    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
             decoder_attention_mask = batch.decoder_attention_mask[

@@ -644,6 +646,8 @@ class Seq2SeqLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

+        start_decode = time.time_ns()
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True

@@ -788,7 +792,9 @@ class Seq2SeqLM(Model):

         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # We don't need input_ids after the prefill forward
         batch.input_ids = None

@@ -799,4 +805,6 @@ class Seq2SeqLM(Model):
         batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
         batch.padding_right_offset -= 1

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
@@ -1,6 +1,7 @@
 import asyncio
 import os
 import torch
+import time

 from grpc import aio
 from loguru import logger

@@ -76,6 +77,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

     async def Prefill(self, request, context):
+        start = time.time_ns()
         if (
             self.model.batch_type == IdeficsCausalLMBatch
         ): # Hack, i would rather use kwargs in the `from_pb` call

@@ -91,15 +93,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

-        generations, next_batch = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
         )

     async def Decode(self, request, context):
+        start = time.time_ns()
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")

@@ -114,16 +120,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
+            start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)
+            concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]
+            concat_ns = None

-        generations, next_batch = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
+            concat_ns=concat_ns,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
        )
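Note: `total_ns` is measured around the whole Prefill/Decode handler, so it also covers batch deserialization, optional concatenation, and cache bookkeeping; the gap between `total_ns` and the per-phase timers is that overhead. An illustrative check on the Rust client side (the `response` binding is hypothetical, the field names come from the proto change above):

    // Sketch: server-side overhead not attributed to forward/decode/concat in a DecodeResponse.
    let accounted = response.forward_ns + response.decode_ns + response.concat_ns.unwrap_or(0);
    let overhead_ns = response.total_ns.saturating_sub(accounted);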
@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")

     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):