feat: decrease IPC proto size (#367)

Closes #307 #308
OlivierDehaene 2023-05-24 19:19:57 +02:00 committed by GitHub
parent d31562f300
commit 218c9adaa5
14 changed files with 108 additions and 88 deletions

View File

@@ -1,6 +1,6 @@
 use std::time::{Duration, Instant};
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
 use tokenizers::{Tokenizer, TruncationDirection};

@@ -126,7 +126,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
+) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
     let requests = (0..batch_size)
         .map(|id| Request {

@@ -180,7 +180,7 @@ async fn prefill(
 }

 /// Run a full decode
-async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
+async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
     let mut decode_length = 0;
     let batch_size = batch.size;

View File

@@ -100,6 +100,17 @@ message Batch {
   uint32 max_tokens = 4;
 }

+message CachedBatch {
+  /// Batch ID
+  uint64 id = 1;
+  /// Individual requests ids
+  repeated uint64 request_ids = 2;
+  /// Batch size (==len(requests))
+  uint32 size = 3;
+  /// Maximum number of tokens this batch will grow to
+  uint32 max_tokens = 4;
+}
+
 enum FinishReason {
   FINISH_REASON_LENGTH = 0;
   FINISH_REASON_EOS_TOKEN = 1;

@@ -140,19 +151,19 @@ message Generation {
   /// Is it a special token
   bool token_is_special = 6;
   /// Complete generated text
-  GeneratedText generated_text = 7;
+  optional GeneratedText generated_text = 7;
 }

 message FilterBatchRequest {
   /// Batch ID
   uint64 batch_id = 1;
   /// Requests to keep
-  repeated Request keep_requests = 2;
+  repeated uint64 request_ids = 2;
 }

 message FilterBatchResponse {
   /// Filtered Batch (cached)
-  Batch batch = 1;
+  CachedBatch batch = 1;
 }

@@ -165,17 +176,17 @@ message PrefillResponse {
   /// Generation
   repeated Generation generations = 1;
   /// Next batch (cached)
-  optional Batch batch = 2;
+  optional CachedBatch batch = 2;
 }

 message DecodeRequest {
   /// Cached batches
-  repeated Batch batches = 1;
+  repeated CachedBatch batches = 1;
 }

 message DecodeResponse {
   /// Decodes
   repeated Generation generations = 1;
   /// Next batch (cached)
-  optional Batch batch = 2;
+  optional CachedBatch batch = 2;
 }
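
This is the core of the change: once a batch is cached on the shards, the router refers to it through the new CachedBatch message, and FilterBatchRequest carries bare request ids instead of echoing full Request messages back over IPC on every decode and filter call. Below is a minimal sketch of the new messages from the Python side, assuming the generated module is generate_pb2 at the path imported elsewhere in this commit; the ids and values are illustrative only.

from text_generation_server.pb import generate_pb2

# A cached batch is referenced purely by ids and sizes; no request payloads are resent.
cached = generate_pb2.CachedBatch(
    id=0,
    request_ids=[0, 1, 2],
    size=3,
    max_tokens=1024,  # illustrative value
)

# Filtering now names the requests to keep by id instead of resending them.
filter_request = generate_pb2.FilterBatchRequest(
    batch_id=cached.id,
    request_ids=[0, 2],
)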

View File

@@ -83,11 +83,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            keep_requests,
+            request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

@@ -99,7 +99,10 @@ impl Client {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((response.generations, response.batch))

@@ -112,8 +115,8 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((response.generations, response.batch))

View File

@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
-    Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
-    Request, StoppingCriteriaParameters,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    PrefillTokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

View File

@@ -1,5 +1,5 @@
 /// Multi shard Client
-use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;

@@ -76,12 +76,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()

@@ -92,13 +92,16 @@ impl ShardedClient {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }

@@ -110,14 +113,14 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }

@@ -125,8 +128,8 @@ impl ShardedClient {
 /// Merge generations from the different model shards
 fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
-) -> Result<(Vec<Generation>, Option<Batch>)> {
+    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
+) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
     let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
     for (mut shard_generations, _) in results.into_iter() {

View File

@@ -12,7 +12,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};

@@ -352,7 +352,7 @@ async fn prefill(
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

@@ -386,10 +386,10 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batches: Vec<Batch>,
+    batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

@@ -425,9 +425,9 @@ async fn decode(
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<Batch>,
+    next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let mut batch = next_batch?;

     // No need to filter

@@ -438,9 +438,9 @@ async fn filter_batch(
     let id = batch.id;

     // Retain only requests that are still in entries
-    batch.requests.retain(|r| entries.contains_key(&r.id));
+    batch.request_ids.retain(|id| entries.contains_key(id));

-    if batch.requests.is_empty() {
+    if batch.request_ids.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache

@@ -450,7 +450,7 @@ async fn filter_batch(
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.requests).await.unwrap()
+        client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }

View File

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()

-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])

     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1

@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])

     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens

@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])

     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens

View File

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )

-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])

     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1

@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])

     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens

@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])

     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens

View File

@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5

-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

@@ -333,7 +333,7 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7

-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])

     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

View File

@@ -53,10 +53,10 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True

-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )

@@ -143,16 +143,17 @@ class CausalLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self

         keep_indices = []

         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         prefix_offsets = []
         read_offsets = []

@@ -165,11 +166,12 @@ class CausalLMBatch(Batch):
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0

-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])

@@ -220,7 +222,7 @@ class CausalLMBatch(Batch):
                 layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
                 del past_values

-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
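
The rewritten filter above is the server-side half of that contract: callers now pass plain request ids, and the batch maps each id back to its position through requests_idx_mapping before slicing its per-request state. A simplified, standalone sketch of that id-to-index bookkeeping follows; the helper name and the toy state list are hypothetical and not part of the actual CausalLMBatch.

from typing import Dict, List


def filter_by_ids(
    requests_idx_mapping: Dict[int, int],
    state: List[str],
    request_ids: List[int],
) -> List[str]:
    # Look each id up in the old mapping, then keep the matching entries in id order.
    keep_indices = [requests_idx_mapping[request_id] for request_id in request_ids]
    return [state[idx] for idx in keep_indices]


# Example: a batch whose requests 10, 11, 12 sit at positions 0, 1, 2.
mapping = {10: 0, 11: 1, 12: 2}
state = ["state for 10", "state for 11", "state for 12"]
print(filter_by_ids(mapping, state, [12, 10]))  # ['state for 12', 'state for 10']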

View File

@@ -62,10 +62,10 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int

-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )

@@ -161,14 +161,14 @@ class FlashCausalLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self

-        single_request = len(requests) == 1
+        single_request = len(request_ids) == 1

         # Cumulative length
         cumulative_length = 0

@@ -176,16 +176,17 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}

-        input_ids = self.input_ids.new_empty(len(requests))
-        position_ids = self.position_ids.new_empty(len(requests))
+        input_ids = self.input_ids.new_empty(len(request_ids))
+        position_ids = self.position_ids.new_empty(len(request_ids))
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
         cu_seqlens_q = torch.arange(
-            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
         )
         max_seqlen = 0
         past_key_values = []

+        requests = []
         all_input_ids = []
         all_input_ids_tensor = []

@@ -198,9 +199,11 @@ class FlashCausalLMBatch(Batch):
         max_tokens = 0

-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+
+            requests.append(self.requests[idx])

             # Get length
             request_input_length = self.input_lengths[idx]

View File

@@ -57,11 +57,11 @@ class Seq2SeqLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int

-    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )

@@ -152,18 +152,17 @@ class Seq2SeqLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(
-        self, requests: List[generate_pb2.Request]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self

         keep_indices = []

         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         decoder_input_lengths = []
         prefix_offsets = []

@@ -180,11 +179,12 @@ class Seq2SeqLMBatch(Batch):
         total_remaining_decode_tokens = 0

-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])

@@ -239,7 +239,7 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

         max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
+            len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
         )

View File

@@ -12,7 +12,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason

 class Batch(ABC):
     @abstractmethod
-    def to_pb(self) -> generate_pb2.Batch:
+    def to_pb(self) -> generate_pb2.CachedBatch:
         raise NotImplementedError

     @classmethod

@@ -26,7 +26,7 @@ class Batch(ABC):
         raise NotImplementedError

     @abstractmethod
-    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+    def filter(self, request_ids: List[int]) -> "Batch":
         raise NotImplementedError

     @classmethod
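
The abstract Batch interface now pins down the id-based contract for every model implementation: to_pb serializes to a CachedBatch and filter takes a list of request ids. A toy illustration of a class honoring the two changed signatures follows; it is hypothetical, not one of the real model batches, and from_pb, concatenate and the rest of the ABC are omitted.

from dataclasses import dataclass
from typing import Dict, List

from text_generation_server.pb import generate_pb2


@dataclass
class ToyBatch:
    batch_id: int
    # request id -> per-request state (a stand-in for the real tensors and offsets)
    entries: Dict[int, str]
    max_tokens: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        # Only ids and sizes cross the IPC boundary.
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=list(self.entries.keys()),
            size=len(self.entries),
            max_tokens=self.max_tokens,
        )

    def filter(self, request_ids: List[int]) -> "ToyBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        return ToyBatch(
            batch_id=self.batch_id,
            entries={rid: self.entries[rid] for rid in request_ids},
            max_tokens=self.max_tokens,
        )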

View File

@@ -42,15 +42,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             self.cache.delete(request.id)
         else:
             self.cache.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.keep_requests)
+        filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())