parent d31562f300
commit 218c9adaa5
@@ -1,6 +1,6 @@
 use std::time::{Duration, Instant};
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
 use tokenizers::{Tokenizer, TruncationDirection};
@@ -126,7 +126,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
+) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
     let requests = (0..batch_size)
         .map(|id| Request {
@@ -180,7 +180,7 @@ async fn prefill(
 }
 
 /// Run a full decode
-async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
+async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
     let mut decode_length = 0;
     let batch_size = batch.size;
 

@@ -100,6 +100,17 @@ message Batch {
     uint32 max_tokens = 4;
 }
 
+message CachedBatch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests ids
+    repeated uint64 request_ids = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
 enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
@@ -140,19 +151,19 @@ message Generation {
     /// Is it a special token
     bool token_is_special = 6;
     /// Complete generated text
-    GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 7;
 }
 
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated Request keep_requests = 2;
+    repeated uint64 request_ids = 2;
 }
 
 message FilterBatchResponse {
     /// Filtered Batch (cached)
-    Batch batch = 1;
+    CachedBatch batch = 1;
 }
 
 
@@ -165,17 +176,17 @@ message PrefillResponse {
     /// Generation
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }
 
 message DecodeRequest {
     /// Cached batches
-    repeated Batch batches = 1;
+    repeated CachedBatch batches = 1;
 }
 
 message DecodeResponse {
     /// Decodes
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }

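With the messages above, the router now filters a cached batch by sending only request ids over the wire instead of full Request messages. For orientation, a rough sketch using the generated Python bindings (the module path and field names appear elsewhere in this diff; the concrete values are invented):

    from text_generation_server.pb import generate_pb2

    # The router asks a shard to drop finished requests from cached batch 0,
    # keeping only ids 3 and 5 (values invented for illustration).
    filter_request = generate_pb2.FilterBatchRequest(
        batch_id=0,
        request_ids=[3, 5],
    )

    # The shard answers with the batch in its lightweight cached form:
    filter_response = generate_pb2.FilterBatchResponse(
        batch=generate_pb2.CachedBatch(
            id=0,
            request_ids=[3, 5],
            size=2,
            max_tokens=2048,
        )
    )
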
@@ -83,11 +83,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            keep_requests,
+            request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
@@ -99,7 +99,10 @@ impl Client {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((response.generations, response.batch))
@@ -112,8 +115,8 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((response.generations, response.batch))

@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
-    Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
-    Request, StoppingCriteriaParameters,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    PrefillTokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

@@ -1,5 +1,5 @@
 /// Multi shard Client
-use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -76,12 +76,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@@ -92,13 +92,16 @@ impl ShardedClient {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }
@@ -110,14 +113,14 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }
@@ -125,8 +128,8 @@ impl ShardedClient {
 
 /// Merge generations from the different model shards
 fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
-) -> Result<(Vec<Generation>, Option<Batch>)> {
+    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
+) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
     let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
 
     for (mut shard_generations, _) in results.into_iter() {

@@ -12,7 +12,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
@@ -352,7 +352,7 @@ async fn prefill(
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
@@ -386,10 +386,10 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batches: Vec<Batch>,
+    batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
@@ -425,9 +425,9 @@ async fn decode(
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<Batch>,
+    next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let mut batch = next_batch?;
 
     // No need to filter
@@ -438,9 +438,9 @@ async fn filter_batch(
     let id = batch.id;
 
     // Retain only requests that are still in entries
-    batch.requests.retain(|r| entries.contains_key(&r.id));
+    batch.request_ids.retain(|id| entries.contains_key(id));
 
-    if batch.requests.is_empty() {
+    if batch.request_ids.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
@@ -450,7 +450,7 @@ async fn filter_batch(
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.requests).await.unwrap()
+        client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
 

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens

@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -333,7 +333,7 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

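The test updates above capture the whole caller-side change: `filter` now takes plain request ids instead of `generate_pb2.Request` objects. As a hedged sketch of what a caller might do with that, the `finished_ids` set below is hypothetical and not part of this diff:

    # Keep every request that has not finished yet; `finished_ids` is a
    # hypothetical set of completed request ids maintained by the caller.
    keep_ids = [r.id for r in next_batch.requests if r.id not in finished_ids]
    next_batch = next_batch.filter(keep_ids)
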
@@ -53,10 +53,10 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -143,16 +143,17 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
         keep_indices = []
 
         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -165,11 +166,12 @@ class CausalLMBatch(Batch):
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
 
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])
@@ -220,7 +222,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping

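To make the id-based bookkeeping in `filter` above easier to follow, here is a toy walk-through with invented values; only the mapping logic mirrors the code in this diff:

    # The old batch holds requests with ids 4, 7 and 9 at positions 0, 1 and 2.
    requests_idx_mapping = {4: 0, 7: 1, 9: 2}
    request_ids = [9, 4]  # ids the router wants to keep

    keep_indices = []
    new_mapping = {}
    for i, request_id in enumerate(request_ids):
        idx = requests_idx_mapping[request_id]  # position in the old batch
        new_mapping[request_id] = i             # position in the filtered batch
        keep_indices.append(idx)

    assert keep_indices == [2, 0]
    assert new_mapping == {9: 0, 4: 1}
    # Every per-request tensor is then sliced with keep_indices,
    # e.g. attention_mask = attention_mask[keep_indices].
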
@@ -62,10 +62,10 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -161,14 +161,14 @@ class FlashCausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
-        single_request = len(requests) == 1
+        single_request = len(request_ids) == 1
 
         # Cumulative length
         cumulative_length = 0
@@ -176,16 +176,17 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
 
-        input_ids = self.input_ids.new_empty(len(requests))
-        position_ids = self.position_ids.new_empty(len(requests))
+        input_ids = self.input_ids.new_empty(len(request_ids))
+        position_ids = self.position_ids.new_empty(len(request_ids))
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
         cu_seqlens_q = torch.arange(
-            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
         )
         max_seqlen = 0
         past_key_values = []
 
+        requests = []
         all_input_ids = []
         all_input_ids_tensor = []
 
@@ -198,9 +199,11 @@ class FlashCausalLMBatch(Batch):
 
         max_tokens = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
 
+            requests.append(self.requests[idx])
+
             # Get length
             request_input_length = self.input_lengths[idx]

@@ -57,11 +57,11 @@ class Seq2SeqLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
-    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -152,18 +152,17 @@ class Seq2SeqLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(
-        self, requests: List[generate_pb2.Request]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
         keep_indices = []
 
         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         decoder_input_lengths = []
         prefix_offsets = []
@@ -180,11 +179,12 @@ class Seq2SeqLMBatch(Batch):
 
         total_remaining_decode_tokens = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
 
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -239,7 +239,7 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
         max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
+            len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
         )
 

@@ -12,7 +12,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason
 
 class Batch(ABC):
     @abstractmethod
-    def to_pb(self) -> generate_pb2.Batch:
+    def to_pb(self) -> generate_pb2.CachedBatch:
         raise NotImplementedError
 
     @classmethod
@@ -26,7 +26,7 @@ class Batch(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+    def filter(self, request_ids: List[int]) -> "Batch":
         raise NotImplementedError
 
     @classmethod

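Taken together, the abstract interface above now asks every batch implementation for an id-keyed `filter` and a `CachedBatch`-shaped `to_pb`. A minimal sketch of that contract follows; the class name and stored fields are hypothetical, the real classes also slice tensors and recompute `max_tokens`, and other abstract methods such as `from_pb` and `concatenate` are omitted:

    from dataclasses import dataclass
    from typing import Dict, List

    from text_generation_server.pb import generate_pb2


    @dataclass
    class ToyBatch:
        # Hypothetical minimal state; real batches also carry model tensors.
        batch_id: int
        requests: List[generate_pb2.Request]
        requests_idx_mapping: Dict[int, int]
        max_tokens: int

        def to_pb(self) -> generate_pb2.CachedBatch:
            # Only ids cross the wire now, never the full Request messages.
            return generate_pb2.CachedBatch(
                id=self.batch_id,
                request_ids=[r.id for r in self.requests],
                size=len(self.requests),
                max_tokens=self.max_tokens,
            )

        def filter(self, request_ids: List[int]) -> "ToyBatch":
            # Look up kept positions by id, then renumber the mapping.
            keep = [self.requests_idx_mapping[rid] for rid in request_ids]
            self.requests = [self.requests[i] for i in keep]
            self.requests_idx_mapping = {rid: i for i, rid in enumerate(request_ids)}
            return self
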
@@ -42,15 +42,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             self.cache.delete(request.id)
         else:
             self.cache.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()
 
     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.keep_requests)
+        filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())