feat: decrease IPC proto size (#367)

Closes #307 #308
This commit is contained in:
OlivierDehaene 2023-05-24 19:19:57 +02:00 committed by GitHub
parent d31562f300
commit 218c9adaa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 108 additions and 88 deletions

View File

@ -1,6 +1,6 @@
use std::time::{Duration, Instant};
use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
@ -126,7 +126,7 @@ async fn prefill(
batch_size: u32,
decode_length: u32,
client: &mut ShardedClient,
) -> Result<(Prefill, Batch), ClientError> {
) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests
let requests = (0..batch_size)
.map(|id| Request {
@ -180,7 +180,7 @@ async fn prefill(
}
/// Run a full decode
async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
let mut decode_length = 0;
let batch_size = batch.size;

View File

@ -100,6 +100,17 @@ message Batch {
uint32 max_tokens = 4;
}
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
@ -140,19 +151,19 @@ message Generation {
/// Is it a special token
bool token_is_special = 6;
/// Complete generated text
GeneratedText generated_text = 7;
optional GeneratedText generated_text = 7;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated Request keep_requests = 2;
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
Batch batch = 1;
CachedBatch batch = 1;
}
@ -165,17 +176,17 @@ message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
optional CachedBatch batch = 2;
}
message DecodeRequest {
/// Cached batches
repeated Batch batches = 1;
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
optional CachedBatch batch = 2;
}

View File

@ -83,11 +83,11 @@ impl Client {
pub async fn filter_batch(
&mut self,
batch_id: u64,
keep_requests: Vec<Request>,
) -> Result<Option<Batch>> {
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
keep_requests,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
@ -99,7 +99,10 @@ impl Client {
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
@ -112,8 +115,8 @@ impl Client {
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))

View File

@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
Request, StoppingCriteriaParameters,
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
PrefillTokens, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -1,5 +1,5 @@
/// Multi shard Client
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
use crate::{ClientError, Result};
use futures::future::join_all;
use tonic::transport::Uri;
@ -76,12 +76,12 @@ impl ShardedClient {
pub async fn filter_batch(
&mut self,
batch_id: u64,
keep_requests: Vec<Request>,
) -> Result<Option<Batch>> {
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
@ -92,13 +92,16 @@ impl ShardedClient {
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
}
@ -110,14 +113,14 @@ impl ShardedClient {
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
}
@ -125,8 +128,8 @@ impl ShardedClient {
/// Merge generations from the different model shards
fn merge_generations(
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
for (mut shard_generations, _) in results.into_iter() {

View File

@ -12,7 +12,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
@ -352,7 +352,7 @@ async fn prefill(
batch: Batch,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<Batch> {
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
@ -386,10 +386,10 @@ async fn prefill(
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<Batch>,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<Batch> {
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
@ -425,9 +425,9 @@ async fn decode(
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<Batch>,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<Batch> {
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
@ -438,9 +438,9 @@ async fn filter_batch(
let id = batch.id;
// Retain only requests that are still in entries
batch.requests.retain(|r| entries.contains_key(&r.id));
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.requests.is_empty() {
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
@ -450,7 +450,7 @@ async fn filter_batch(
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.requests).await.unwrap()
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}

View File

@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
next_batch = next_batch.filter([next_batch.requests[0]])
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@ -286,7 +286,7 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
@ -309,7 +309,7 @@ def test_batch_concatenate(
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1]])
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens

View File

@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)
next_batch = next_batch.filter([next_batch.requests[0]])
next_batch = next_batch.filter([next_batch.requests[0].id])
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@ -285,7 +285,7 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@ -306,7 +306,7 @@ def test_batch_concatenate(
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[1]])
next_batch = next_batch.filter([next_batch.requests[1].id])
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens

View File

@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0]])
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
@ -323,7 +323,7 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
@ -333,7 +333,7 @@ def test_batch_concatenate(
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter([next_batch.requests[1]])
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None

View File

@ -53,10 +53,10 @@ class CausalLMBatch(Batch):
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
requests=self.requests,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@ -143,16 +143,17 @@ class CausalLMBatch(Batch):
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
if len(requests) == 0:
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
@ -165,11 +166,12 @@ class CausalLMBatch(Batch):
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
@ -220,7 +222,7 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping

View File

@ -62,10 +62,10 @@ class FlashCausalLMBatch(Batch):
# Maximum number of tokens this batch will grow to
max_tokens: int
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
requests=self.requests,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@ -161,14 +161,14 @@ class FlashCausalLMBatch(Batch):
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
if len(requests) == 0:
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(requests) == len(self):
if len(request_ids) == len(self):
return self
single_request = len(requests) == 1
single_request = len(request_ids) == 1
# Cumulative length
cumulative_length = 0
@ -176,16 +176,17 @@ class FlashCausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_ids = self.input_ids.new_empty(len(requests))
position_ids = self.position_ids.new_empty(len(requests))
input_ids = self.input_ids.new_empty(len(request_ids))
position_ids = self.position_ids.new_empty(len(request_ids))
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
cu_seqlens_q = torch.arange(
0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
)
max_seqlen = 0
past_key_values = []
requests = []
all_input_ids = []
all_input_ids_tensor = []
@ -198,9 +199,11 @@ class FlashCausalLMBatch(Batch):
max_tokens = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
# Get length
request_input_length = self.input_lengths[idx]

View File

@ -57,11 +57,11 @@ class Seq2SeqLMBatch(Batch):
# Maximum number of tokens this batch will grow to
max_tokens: int
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
def to_pb(self) -> generate_pb2.CachedBatch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
return generate_pb2.CachedBatch(
id=self.batch_id,
requests=self.requests,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@ -152,18 +152,17 @@ class Seq2SeqLMBatch(Batch):
)
@tracer.start_as_current_span("filter")
def filter(
self, requests: List[generate_pb2.Request]
) -> Optional["Seq2SeqLMBatch"]:
if len(requests) == 0:
def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
decoder_input_lengths = []
prefix_offsets = []
@ -180,11 +179,12 @@ class Seq2SeqLMBatch(Batch):
total_remaining_decode_tokens = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
@ -239,7 +239,7 @@ class Seq2SeqLMBatch(Batch):
layer[3] = layer[3][keep_indices, :, -max_input_length:]
max_tokens = (
len(requests) * (max_input_length + max_decoder_input_length)
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
)

View File

@ -12,7 +12,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC):
@abstractmethod
def to_pb(self) -> generate_pb2.Batch:
def to_pb(self) -> generate_pb2.CachedBatch:
raise NotImplementedError
@classmethod
@ -26,7 +26,7 @@ class Batch(ABC):
raise NotImplementedError
@abstractmethod
def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
def filter(self, request_ids: List[int]) -> "Batch":
raise NotImplementedError
@classmethod

View File

@ -42,15 +42,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
self.cache.delete(request.id)
else:
self.cache.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return generate_pb2.ClearCacheResponse()
async def FilterBatch(self, request, context):
batch = self.cache.pop(request.batch_id)
if batch is None:
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
filtered_batch = batch.filter(request.keep_requests)
filtered_batch = batch.filter(request.request_ids)
self.cache.set(filtered_batch)
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())