parent d31562f300
commit 218c9adaa5
@@ -1,6 +1,6 @@
 use std::time::{Duration, Instant};
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
 use tokenizers::{Tokenizer, TruncationDirection};
@@ -126,7 +126,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
-) -> Result<(Prefill, Batch), ClientError> {
+) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
     let requests = (0..batch_size)
         .map(|id| Request {
@@ -180,7 +180,7 @@ async fn prefill(
 }
 
 /// Run a full decode
-async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
+async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
     let mut decode_length = 0;
     let batch_size = batch.size;
 
@@ -100,6 +100,17 @@ message Batch {
     uint32 max_tokens = 4;
 }
 
+message CachedBatch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests ids
+    repeated uint64 request_ids = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
 enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
@@ -140,19 +151,19 @@ message Generation {
     /// Is it a special token
     bool token_is_special = 6;
     /// Complete generated text
-    GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 7;
 }
 
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated Request keep_requests = 2;
+    repeated uint64 request_ids = 2;
 }
 
 message FilterBatchResponse {
     /// Filtered Batch (cached)
-    Batch batch = 1;
+    CachedBatch batch = 1;
 }
 
 
@@ -165,17 +176,17 @@ message PrefillResponse {
     /// Generation
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }
 
 message DecodeRequest {
     /// Cached batches
-    repeated Batch batches = 1;
+    repeated CachedBatch batches = 1;
 }
 
 message DecodeResponse {
     /// Decodes
     repeated Generation generations = 1;
     /// Next batch (cached)
-    optional Batch batch = 2;
+    optional CachedBatch batch = 2;
 }

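To make the new wire contract concrete, here is a minimal Python sketch built with the generated generate_pb2 stubs referenced elsewhere in this diff. The ids, size and token budget are made-up values for illustration only.

from text_generation_server.pb import generate_pb2

# Shards answer prefill/decode with a CachedBatch: only request ids travel back,
# the full Request payloads stay in the shard's cache.
cached = generate_pb2.CachedBatch(
    id=0,
    request_ids=[10, 11, 12],
    size=3,
    max_tokens=1024,
)

# The router now asks shards to drop finished requests by id only.
filter_request = generate_pb2.FilterBatchRequest(
    batch_id=cached.id,
    request_ids=[10, 12],
)
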
@@ -83,11 +83,11 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            keep_requests,
+            request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
@@ -99,7 +99,10 @@ impl Client {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((response.generations, response.batch))
@@ -112,8 +115,8 @@ impl Client {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
         Ok((response.generations, response.batch))

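The same RPCs can be exercised directly with the generated Python gRPC stub, which is handy for checking the id-only filtering by hand. This is a hedged sketch: the stub class name follows the usual protoc naming for the TextGenerationService, and the unix socket address is an assumption, not something defined in this diff.

import grpc
from text_generation_server.pb import generate_pb2, generate_pb2_grpc

# Address is illustrative; shards normally listen on a unix socket.
channel = grpc.insecure_channel("unix:///tmp/text-generation-server-0")
stub = generate_pb2_grpc.TextGenerationServiceStub(channel)

# Mirrors Client::filter_batch above: only ids cross the wire, a CachedBatch comes back.
response = stub.FilterBatch(
    generate_pb2.FilterBatchRequest(batch_id=0, request_ids=[10, 12])
)
print(response.batch.size, list(response.batch.request_ids))
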
@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
-    Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
-    Request, StoppingCriteriaParameters,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    PrefillTokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

@@ -1,5 +1,5 @@
 /// Multi shard Client
-use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
+use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -76,12 +76,12 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        keep_requests: Vec<Request>,
-    ) -> Result<Option<Batch>> {
+        request_ids: Vec<u64>,
+    ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@@ -92,13 +92,16 @@ impl ShardedClient {
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
-    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
+    pub async fn prefill(
+        &mut self,
+        batch: Batch,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }
@@ -110,14 +113,14 @@ impl ShardedClient {
     #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<Generation>, Option<Batch>)> {
+        batches: Vec<CachedBatch>,
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
             join_all(futures).await.into_iter().collect();
         merge_generations(results?)
     }
@@ -125,8 +128,8 @@ impl ShardedClient {
 
 /// Merge generations from the different model shards
 fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<Batch>)>,
-) -> Result<(Vec<Generation>, Option<Batch>)> {
+    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
+) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
     let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
 
     for (mut shard_generations, _) in results.into_iter() {

@@ -12,7 +12,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
@@ -352,7 +352,7 @@ async fn prefill(
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
@@ -386,10 +386,10 @@ async fn prefill(
 #[instrument(skip_all)]
 async fn decode(
     client: &mut ShardedClient,
-    batches: Vec<Batch>,
+    batches: Vec<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
     generation_health: &Arc<AtomicBool>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
@@ -425,9 +425,9 @@ async fn decode(
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<Batch>,
+    next_batch: Option<CachedBatch>,
     entries: &IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> Option<CachedBatch> {
     let mut batch = next_batch?;
 
     // No need to filter
@@ -438,9 +438,9 @@ async fn filter_batch(
     let id = batch.id;
 
     // Retain only requests that are still in entries
-    batch.requests.retain(|r| entries.contains_key(&r.id));
+    batch.request_ids.retain(|id| entries.contains_key(id));
 
-    if batch.requests.is_empty() {
+    if batch.request_ids.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
@@ -450,7 +450,7 @@ async fn filter_batch(
     } else {
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, batch.requests).await.unwrap()
+        client.filter_batch(id, batch.request_ids).await.unwrap()
     }
 }
 
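The router-side change above boils down to a small decision per step: drop the ids of finished requests, then either keep the cached batch, clear the shard cache, or send a FilterBatchRequest. A hypothetical Python mirror of that decision, for illustration only (the actual logic stays in the Rust router):

from typing import List, Set, Tuple

def filter_decision(request_ids: List[int], live_ids: Set[int]) -> Tuple[str, List[int]]:
    """Action the router takes for a CachedBatch after some requests finish."""
    kept = [rid for rid in request_ids if rid in live_ids]
    if len(kept) == len(request_ids):
        return "keep", kept          # nothing finished: reuse the cached batch as-is
    if not kept:
        return "clear_cache", []     # everything finished: drop the shard cache entry
    return "filter_batch", kept      # ask every shard to keep only these ids

# Example: requests 10 and 12 are still generating, 11 just finished.
print(filter_decision([10, 11, 12], {10, 12}))  # ('filter_batch', [10, 12])
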
@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -286,7 +286,7 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_bloom_batch.stopping_criterias[0].max_new_tokens
@@ -309,7 +309,7 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens

@@ -178,7 +178,7 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
@@ -285,7 +285,7 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     for _ in range(
         default_causal_lm_batch.stopping_criterias[0].max_new_tokens
@@ -306,7 +306,7 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens

@@ -190,7 +190,7 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0]])
+    next_batch = next_batch.filter([next_batch.requests[0].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)
@@ -323,7 +323,7 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter([next_batch.requests[0], next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None
@@ -333,7 +333,7 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter([next_batch.requests[1]])
+    next_batch = next_batch.filter([next_batch.requests[1].id])
 
     generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

@@ -53,10 +53,10 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -143,16 +143,17 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
         keep_indices = []
 
         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -165,11 +166,12 @@ class CausalLMBatch(Batch):
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
 
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])
@@ -220,7 +222,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping

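The filter() changes here (and in the flash and seq2seq batches below) all implement the same pattern: translate the incoming request ids into positions in the current batch, then gather every per-request list and tensor at those positions. A stripped-down sketch of that core step, assuming an object with a requests_idx_mapping dict; the real methods also re-slice attention masks, position ids and the KV cache:

from typing import Dict, List, Tuple

def gather_kept(
    requests_idx_mapping: Dict[int, int], request_ids: List[int]
) -> Tuple[List[int], Dict[int, int]]:
    """Map request ids to their old positions and build the new id -> position mapping."""
    keep_indices: List[int] = []
    new_mapping: Dict[int, int] = {}
    for i, request_id in enumerate(request_ids):
        keep_indices.append(requests_idx_mapping[request_id])  # old slot in the batch
        new_mapping[request_id] = i                            # slot after filtering
    return keep_indices, new_mapping

# keep_indices can then index every per-request structure, e.g.:
# batch.input_ids = batch.input_ids[keep_indices]
# batch.requests = [batch.requests[idx] for idx in keep_indices]
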
@@ -62,10 +62,10 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -161,14 +161,14 @@ class FlashCausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
-        single_request = len(requests) == 1
+        single_request = len(request_ids) == 1
 
         # Cumulative length
         cumulative_length = 0
@@ -176,16 +176,17 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
 
-        input_ids = self.input_ids.new_empty(len(requests))
-        position_ids = self.position_ids.new_empty(len(requests))
+        input_ids = self.input_ids.new_empty(len(request_ids))
+        position_ids = self.position_ids.new_empty(len(request_ids))
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
         cu_seqlens_q = torch.arange(
-            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+            0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
         )
         max_seqlen = 0
         past_key_values = []
 
+        requests = []
         all_input_ids = []
         all_input_ids_tensor = []
 
@@ -198,9 +199,11 @@ class FlashCausalLMBatch(Batch):
 
         max_tokens = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+
+            requests.append(self.requests[idx])
 
             # Get length
             request_input_length = self.input_lengths[idx]

@@ -57,11 +57,11 @@ class Seq2SeqLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
-    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
-        return generate_pb2.Batch(
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+        return generate_pb2.CachedBatch(
             id=self.batch_id,
-            requests=self.requests,
+            request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
         )
@@ -152,18 +152,17 @@ class Seq2SeqLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(
-        self, requests: List[generate_pb2.Request]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        if len(requests) == 0:
+    def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+        if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
+        if len(request_ids) == len(self):
             return self
 
         keep_indices = []
 
         # New values after filtering
         requests_idx_mapping = {}
+        requests = []
         input_lengths = []
         decoder_input_lengths = []
         prefix_offsets = []
@@ -180,11 +179,12 @@ class Seq2SeqLMBatch(Batch):
 
         total_remaining_decode_tokens = 0
 
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
             keep_indices.append(idx)
 
+            requests.append(self.requests[idx])
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -239,7 +239,7 @@ class Seq2SeqLMBatch(Batch):
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
         max_tokens = (
-            len(requests) * (max_input_length + max_decoder_input_length)
+            len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
         )
 
@@ -12,7 +12,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason
 
 class Batch(ABC):
     @abstractmethod
-    def to_pb(self) -> generate_pb2.Batch:
+    def to_pb(self) -> generate_pb2.CachedBatch:
         raise NotImplementedError
 
     @classmethod
@@ -26,7 +26,7 @@ class Batch(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def filter(self, requests: List[generate_pb2.Request]) -> "Batch":
+    def filter(self, request_ids: List[int]) -> "Batch":
         raise NotImplementedError
 
     @classmethod

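Taken together, the abstract interface now commits every model batch to the id-based contract: to_pb() reports only ids back to the router, filter() accepts only ids. A hypothetical minimal class honoring that contract, for illustration only (ToyBatch is not part of this change and skips all tensor state):

from typing import List, Optional
from text_generation_server.pb import generate_pb2

class ToyBatch:  # would subclass Batch; kept standalone here for brevity
    def __init__(self, batch_id, requests, max_tokens):
        self.batch_id = batch_id
        self.requests = requests          # full Request protos stay server-side
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.requests)

    def to_pb(self) -> generate_pb2.CachedBatch:
        # Only ids, size and the token budget go back over the wire.
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    def filter(self, request_ids: List[int]) -> Optional["ToyBatch"]:
        by_id = {r.id: r for r in self.requests}
        kept = [by_id[rid] for rid in request_ids]  # preserve the requested order
        return ToyBatch(self.batch_id, kept, self.max_tokens)
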
@@ -42,15 +42,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             self.cache.delete(request.id)
         else:
             self.cache.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         return generate_pb2.ClearCacheResponse()
 
     async def FilterBatch(self, request, context):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.keep_requests)
+        filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())