FlashCausalLM implem
This commit is contained in:
parent
6983ec9537
commit
73c3903214
|
@ -224,11 +224,18 @@ message FilterBatchRequest {
|
|||
repeated uint64 terminated_request_ids = 3;
|
||||
}
|
||||
|
||||
message TerminatedGeneration {
|
||||
// Request ID
|
||||
uint64 id = 1;
|
||||
// Generated text
|
||||
GeneratedText generated_text = 2;
|
||||
}
|
||||
|
||||
message FilterBatchResponse {
|
||||
/// Filtered Batch (cached)
|
||||
CachedBatch batch = 1;
|
||||
/// Terminated generations
|
||||
repeated GeneratedText terminated_generations = 2;
|
||||
repeated TerminatedGeneration terminated_generations = 2;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ impl Client {
|
|||
batch_id: u64,
|
||||
kept_requests: Vec<KeptRequest>,
|
||||
terminated_request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
kept_requests,
|
||||
|
@ -100,7 +100,7 @@ impl Client {
|
|||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
Ok((filtered_batch.batch, filtered_batch.terminated_generations))
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
|
|
|
@ -8,6 +8,6 @@ pub use client::Client;
|
|||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
use crate::{v3, Health, ShardInfo};
|
||||
use crate::{ClientError, Result};
|
||||
|
||||
use crate::v3::{Chunk, InfoResponse, Input};
|
||||
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
|
@ -86,7 +86,7 @@ impl ShardedClient {
|
|||
batch_id: u64,
|
||||
kept_requests: Vec<KeptRequest>,
|
||||
terminated_request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
|
|
|
@ -5,12 +5,14 @@ use crate::infer::{
|
|||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
|
||||
use text_generation_client::v3::{
|
||||
Batch, CachedBatch, Generation, KeptRequest, ShardedClient, TerminatedGeneration,
|
||||
};
|
||||
use text_generation_client::ClientError;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
|
||||
|
@ -243,11 +245,38 @@ async fn prefill(
|
|||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
// Filter and send finished generations
|
||||
let filtered_stream_responses = filter_send_ended_generations(generations, entries);
|
||||
|
||||
// Iterate on intermediate generations
|
||||
for (id, stream_responses) in filtered_stream_responses {
|
||||
// Get entry
|
||||
let entry = entries
|
||||
.get_mut(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Send intermediate responses
|
||||
if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}) {
|
||||
// Sending failed, remove entry
|
||||
entries
|
||||
.remove(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
}
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries, false).await;
|
||||
let next_batch = match next_batch {
|
||||
Some(batch) if batch.size as usize != entries.len() => {
|
||||
let (filtered_batch, _) =
|
||||
filter_batch(client, batch, entries, &IntMap::default()).await;
|
||||
filtered_batch
|
||||
}
|
||||
batch => batch,
|
||||
};
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||
|
@ -285,13 +314,32 @@ async fn decode(
|
|||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
let updated = filter_update_allocations(entries).await;
|
||||
// Filter and send finished generations
|
||||
let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
|
||||
// Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be
|
||||
// re-allocated,
|
||||
// Allocated new blocks for entries that go over their allocation
|
||||
// Filter entries that couldn't be re-allocated and add them to `terminated_entries`
|
||||
let (force_update, terminated_entries) =
|
||||
filter_send_update_allocations(entries, &mut filtered_stream_responses);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries, updated).await;
|
||||
let next_batch = match next_batch {
|
||||
// Run Only on re-allocation or if entries were filtered
|
||||
Some(batch) if batch.size as usize != entries.len() || force_update => {
|
||||
// Filter next batch: remove requests that were stopped and update blocks/slots
|
||||
let (filtered_batch, terminated_generations) =
|
||||
filter_batch(client, batch, entries, &terminated_entries).await;
|
||||
send_terminated_generations(
|
||||
terminated_generations,
|
||||
terminated_entries,
|
||||
filtered_stream_responses,
|
||||
);
|
||||
|
||||
filtered_batch
|
||||
}
|
||||
batch => batch,
|
||||
};
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||
|
@ -320,27 +368,20 @@ async fn decode(
|
|||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
batch: CachedBatch,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
force_update: bool,
|
||||
) -> Option<CachedBatch> {
|
||||
let batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() && !force_update {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
terminated_entries: &IntMap<u64, Entry>,
|
||||
) -> (Option<CachedBatch>, Vec<TerminatedGeneration>) {
|
||||
let id = batch.id;
|
||||
if entries.is_empty() {
|
||||
if entries.is_empty() && terminated_entries.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
Default::default()
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// Collect new blocks/slots
|
||||
let updated_requests = entries
|
||||
.iter()
|
||||
.map(|(request_id, entry)| {
|
||||
|
@ -348,7 +389,7 @@ async fn filter_batch(
|
|||
.block_allocation
|
||||
.as_ref()
|
||||
.map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
|
||||
.unwrap_or((Vec::new(), Vec::new()));
|
||||
.unwrap_or_default();
|
||||
|
||||
KeptRequest {
|
||||
id: *request_id,
|
||||
|
@ -358,111 +399,207 @@ async fn filter_batch(
|
|||
})
|
||||
.collect();
|
||||
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client
|
||||
.filter_batch(id, updated_requests, Vec::new())
|
||||
.filter_batch(
|
||||
id,
|
||||
updated_requests,
|
||||
terminated_entries.keys().map(|v| *v).collect(),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
///
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
fn send_terminated_generations(
|
||||
terminated_generations: Vec<TerminatedGeneration>,
|
||||
terminated_entries: IntMap<u64, Entry>,
|
||||
mut stream_responses: IntMap<u64, Vec<InferStreamResponse>>,
|
||||
) {
|
||||
// Receive final message for terminated generations
|
||||
'terminated_generations: for terminated_generation in terminated_generations {
|
||||
let id = terminated_generation.id;
|
||||
// Get entry for this generation
|
||||
let entry = terminated_entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
// Get previous `InferStreamResponse` for this generation
|
||||
let stream_responses = stream_responses
|
||||
.remove(&id)
|
||||
.expect("ID not found in stream_responses. This is a bug.");
|
||||
|
||||
// Peekable iterator to know when we are at the last `InferStreamResponse`
|
||||
let mut iterator = stream_responses.into_iter().peekable();
|
||||
|
||||
while let Some(stream_response) = iterator.next() {
|
||||
let response = if iterator.peek().is_none() {
|
||||
// Last `InferStreamResponse::Intermediate`
|
||||
let (token, top_tokens) = match stream_response {
|
||||
InferStreamResponse::Intermediate { token, top_tokens } => (token, top_tokens),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
// Modify it to be a `InferStreamResponse::End` with the new OutOfResources finish
|
||||
// reason
|
||||
InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(
|
||||
terminated_generation
|
||||
.generated_text
|
||||
.clone()
|
||||
.expect("Generated Text is None. This is a bug."),
|
||||
),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}
|
||||
} else {
|
||||
stream_response
|
||||
};
|
||||
|
||||
// Send responses
|
||||
if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}) {
|
||||
continue 'terminated_generations;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send `InferStreamResponse::End` to `Infer` for finished entries and remove them from `entries`
|
||||
/// Returns filtered `InferStreamResponse::Intermediate` generations
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_ended_generations(
|
||||
generations: Vec<Generation>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> IntMap<u64, Vec<InferStreamResponse>> {
|
||||
generations.into_iter().filter_map(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get_mut(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
entry.cache_length = generation.cache_length;
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Check if block allocations need to be extended
|
||||
/// If we don't have enough blocks, request will be filtered with an OutOfPages error
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
|
||||
let ids: Vec<u64> = entries
|
||||
.iter()
|
||||
.filter_map(|(id, entry)| {
|
||||
entry
|
||||
.block_allocation
|
||||
.as_ref()
|
||||
.map(|block_allocation| {
|
||||
if entry.cache_length > block_allocation.len() as u32 {
|
||||
// We need to re-allocate
|
||||
Some(*id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(None)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for id in ids.iter() {
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let extension = {
|
||||
let entry = entries
|
||||
.get_mut(id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
entry
|
||||
.block_allocation
|
||||
.as_mut()
|
||||
.expect("We checked that the block allocation exists above")
|
||||
.extend()
|
||||
};
|
||||
|
||||
if extension.is_err() {
|
||||
let entry = entries
|
||||
.remove(id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::OutOfPages;
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry.response_tx.send(Err(err)).unwrap_or(());
|
||||
}
|
||||
}
|
||||
|
||||
// If ids is not empty, we need to update
|
||||
!ids.is_empty()
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
return Ok(true);
|
||||
// Remove from entries and filter
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
// Update cache length
|
||||
entry.cache_length = generation.cache_length;
|
||||
|
||||
let (finished, stream_responses) = map_generation(generation, entry);
|
||||
// If the generation has ended for this request, we send the responses to the channel and
|
||||
// remove the entry to drop it and free its blocks
|
||||
if finished {
|
||||
let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
});
|
||||
// Remove from entries and filter
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((id, stream_responses))
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// Send `InferStreamResponse` to `Infer` through an `Entry` response channel
|
||||
#[instrument(skip_all)]
|
||||
fn send_stream_responses(
|
||||
stream_responses: Vec<InferStreamResponse>,
|
||||
entry: &Entry,
|
||||
) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
for response in stream_responses {
|
||||
entry.response_tx.send(Ok(response))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if block allocations need to be extended
|
||||
/// If we don't have enough blocks, request will be filtered with be added to an IntMap of
|
||||
/// terminated entries.
|
||||
/// If at least one entry allocation was extended, we return true to force an update
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_update_allocations(
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
stream_responses: &mut IntMap<u64, Vec<InferStreamResponse>>,
|
||||
) -> (bool, IntMap<u64, Entry>) {
|
||||
let mut updated = false;
|
||||
|
||||
let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
|
||||
let mut terminated_entries =
|
||||
IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
for id in &ids {
|
||||
let entry = entries
|
||||
.get_mut(id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
if let Some(block_allocation) = entry.block_allocation.as_mut() {
|
||||
// Check if allocation can handle the current cache_length
|
||||
if entry.cache_length > block_allocation.len() as u32 {
|
||||
updated = true;
|
||||
|
||||
// Extend allocation by asking for a new block
|
||||
if let Err(err) = block_allocation.extend() {
|
||||
// Failed to extend allocation
|
||||
tracing::error!("Failed to extend allocation: {err}");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_resources");
|
||||
|
||||
// Remove entry
|
||||
let mut entry = entries
|
||||
.remove(id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
// Clear block allocation
|
||||
entry.block_allocation = None;
|
||||
// Add it to terminated entries
|
||||
terminated_entries.insert(*id, entry);
|
||||
// Skip the rest of the logic to not send the intermediate messages
|
||||
// This entry will be terminated and we will need to edit the last intermediate
|
||||
// response to add the complete generated text
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
let stream_response = stream_responses
|
||||
.remove(id)
|
||||
.expect("ID not found in stream_responses. This is a bug.");
|
||||
|
||||
// Send intermediate responses
|
||||
if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}) {
|
||||
// Sending failed, remove entry
|
||||
entries
|
||||
.remove(id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
}
|
||||
|
||||
(updated, terminated_entries)
|
||||
}
|
||||
|
||||
/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
|
||||
fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
|
||||
let mut finished = false;
|
||||
let mut stream_responses = Vec::with_capacity(16);
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
|
@ -475,10 +612,8 @@ fn send_responses(
|
|||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
// Push to stream_responses
|
||||
stream_responses.push(InferStreamResponse::Prefill(prefill_tokens));
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
|
@ -520,26 +655,24 @@ fn send_responses(
|
|||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
finished = true;
|
||||
// Push to stream_responses
|
||||
stream_responses.push(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
// Push to stream_responses
|
||||
stream_responses.push(InferStreamResponse::Intermediate { token, top_tokens });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
(finished, stream_responses)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
|
|
|
@ -402,10 +402,7 @@ class FlashCausalLMBatch(Batch):
|
|||
model: "FlashCausalLM",
|
||||
kept_requests: List[generate_pb2.KeptRequest],
|
||||
terminated_request_ids: List[int],
|
||||
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
|
||||
if len(kept_requests) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
|
||||
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
|
||||
terminated_generations = []
|
||||
for request_id in terminated_request_ids:
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
|
@ -421,13 +418,19 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
generated_text = GeneratedText(
|
||||
output_text,
|
||||
stopping_criteria.current_tokens,
|
||||
generate_pb2.FINISH_REASON_TERMINATED,
|
||||
seed if do_sample else None,
|
||||
terminated_generations.append(
|
||||
generate_pb2.TerminatedGeneration(
|
||||
id=request_id,
|
||||
generated_text=generate_pb2.GeneratedText(
|
||||
text=output_text,
|
||||
generated_tokens=stopping_criteria.current_tokens,
|
||||
finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
|
||||
seed=seed if do_sample else None,
|
||||
),
|
||||
)
|
||||
terminated_generations.append(generated_text)
|
||||
)
|
||||
if not kept_requests:
|
||||
return None, terminated_generations
|
||||
|
||||
device = self.input_ids.device
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class Batch(ABC):
|
|||
model,
|
||||
kept_requests: List[generate_pb2.KeptRequest],
|
||||
terminated_request_ids: List[int],
|
||||
) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
|
||||
) -> Tuple[Optional["Batch"], List[generate_pb2.TerminatedGeneration]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -86,10 +86,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
filtered_batch, terminated_generations = batch.filter(
|
||||
self.model, request.kept_requests, request.terminated_request_ids
|
||||
)
|
||||
if filtered_batch is not None:
|
||||
self.cache.set(filtered_batch)
|
||||
|
||||
return generate_pb2.FilterBatchResponse(
|
||||
batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
|
||||
batch=filtered_batch.to_pb() if filtered_batch is not None else None,
|
||||
terminated_generations=terminated_generations,
|
||||
)
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
|
|
Loading…
Reference in New Issue