589 lines
21 KiB
Rust
589 lines
21 KiB
Rust
/// Batching and inference logic
|
|
use crate::validation::{Validation, ValidationError};
|
|
use crate::{Entry, Queue, Token};
|
|
use crate::{GenerateRequest, PrefillToken};
|
|
use flume::r#async::RecvStream;
|
|
use flume::SendError;
|
|
use futures::future::try_join_all;
|
|
use futures::stream::StreamExt;
|
|
use nohash_hasher::IntMap;
|
|
use std::sync::{
|
|
atomic::{AtomicBool, Ordering},
|
|
Arc,
|
|
};
|
|
use text_generation_client::{
|
|
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
|
};
|
|
use thiserror::Error;
|
|
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
|
use tokio::time::Instant;
|
|
use tracing::{info_span, instrument, Instrument, Span};
|
|
|
|
/// Inference struct
|
|
#[derive(Clone)]
|
|
pub struct Infer {
|
|
/// Validation
|
|
validation: Validation,
|
|
/// Request queue
|
|
queue: Queue,
|
|
/// Shared state
|
|
shared: Arc<Shared>,
|
|
/// Inference limit
|
|
limit_concurrent_requests: Arc<Semaphore>,
|
|
}
|
|
|
|
/// Infer shared state
|
|
struct Shared {
|
|
/// Batching background Tokio task notifier
|
|
batching_task: Notify,
|
|
}
|
|
|
|
impl Infer {
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub(crate) fn new(
|
|
client: ShardedClient,
|
|
validation: Validation,
|
|
waiting_served_ratio: f32,
|
|
max_batch_total_tokens: u32,
|
|
max_waiting_tokens: usize,
|
|
max_concurrent_requests: usize,
|
|
requires_padding: bool,
|
|
generation_health: Arc<AtomicBool>,
|
|
) -> Self {
|
|
// Infer shared state
|
|
let queue = Queue::new(requires_padding);
|
|
let shared = Arc::new(Shared {
|
|
batching_task: Notify::new(),
|
|
});
|
|
|
|
// Spawn batching background task that contains all the inference logic
|
|
tokio::spawn(batching_task(
|
|
client,
|
|
waiting_served_ratio,
|
|
max_batch_total_tokens,
|
|
max_waiting_tokens,
|
|
queue.clone(),
|
|
shared.clone(),
|
|
generation_health,
|
|
));
|
|
|
|
// Inference limit with a semaphore
|
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
|
|
|
Self {
|
|
validation,
|
|
queue,
|
|
shared,
|
|
limit_concurrent_requests: semaphore,
|
|
}
|
|
}
|
|
|
|
/// Add a new request to the queue and return a stream of InferStreamResponse
|
|
#[instrument(skip(self))]
|
|
pub(crate) async fn generate_stream(
|
|
&self,
|
|
request: GenerateRequest,
|
|
) -> Result<
|
|
(
|
|
OwnedSemaphorePermit,
|
|
RecvStream<Result<InferStreamResponse, InferError>>,
|
|
),
|
|
InferError,
|
|
> {
|
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
|
let permit = self
|
|
.clone()
|
|
.limit_concurrent_requests
|
|
.try_acquire_owned()
|
|
.map_err(|err| {
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
|
tracing::error!("{err}");
|
|
err
|
|
})?;
|
|
|
|
// Validate request
|
|
let valid_request = self.validation.validate(request).await.map_err(|err| {
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
|
tracing::error!("{err}");
|
|
err
|
|
})?;
|
|
|
|
// MPSC channel to communicate with the background batching task
|
|
let (response_tx, response_rx) = flume::unbounded();
|
|
|
|
// Append the request to the queue
|
|
self.queue.append(Entry {
|
|
request: valid_request,
|
|
response_tx,
|
|
span: Span::current(),
|
|
temp_span: None,
|
|
queue_time: Instant::now(),
|
|
batch_time: None,
|
|
});
|
|
|
|
// Notify the background task that we have a new entry in the queue that needs
|
|
// to be batched
|
|
self.shared.batching_task.notify_one();
|
|
|
|
// Return stream
|
|
Ok((permit, response_rx.into_stream()))
|
|
}
|
|
|
|
/// Add a new request to the queue and return a InferResponse
|
|
#[instrument(skip(self))]
|
|
pub(crate) async fn generate(
|
|
&self,
|
|
request: GenerateRequest,
|
|
) -> Result<InferResponse, InferError> {
|
|
// Create stream and keep semaphore permit as long as generate lives
|
|
let (_permit, mut stream) = self.generate_stream(request).await?;
|
|
|
|
// Return values
|
|
let mut result_prefill = Vec::new();
|
|
let mut result_tokens = Vec::new();
|
|
let mut result_generated_text = None;
|
|
let mut result_start = None;
|
|
let mut result_queued = None;
|
|
|
|
// Iterate on stream
|
|
while let Some(response) = stream.next().await {
|
|
match response? {
|
|
// Add prefill tokens
|
|
InferStreamResponse::Prefill(tokens) => {
|
|
// Create Token objects
|
|
// We do that here instead of in the Python code as Rust for loops are faster
|
|
result_prefill = tokens
|
|
.ids
|
|
.into_iter()
|
|
.zip(tokens.logprobs.into_iter())
|
|
.zip(tokens.texts.into_iter())
|
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
|
.collect();
|
|
}
|
|
// Push last token
|
|
InferStreamResponse::Token(token) => result_tokens.push(token),
|
|
// Final message
|
|
// Set return values
|
|
InferStreamResponse::End {
|
|
token,
|
|
generated_text,
|
|
start,
|
|
queued,
|
|
} => {
|
|
result_tokens.push(token);
|
|
result_generated_text = Some(generated_text);
|
|
result_start = Some(start);
|
|
result_queued = Some(queued)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check that we received a `InferStreamResponse::End` message
|
|
if let (Some(generated_text), Some(queued), Some(start)) =
|
|
(result_generated_text, result_queued, result_start)
|
|
{
|
|
Ok(InferResponse {
|
|
prefill: result_prefill,
|
|
tokens: result_tokens,
|
|
generated_text,
|
|
queued,
|
|
start,
|
|
})
|
|
} else {
|
|
let err = InferError::IncompleteGeneration;
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
|
tracing::error!("{err}");
|
|
Err(err)
|
|
}
|
|
}
|
|
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
|
|
/// the highest log probability per token
|
|
#[instrument(skip(self))]
|
|
pub(crate) async fn generate_best_of(
|
|
&self,
|
|
request: GenerateRequest,
|
|
best_of: usize,
|
|
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
|
|
// validate best_of parameter separately
|
|
let best_of = self.validation.validate_best_of(best_of)?;
|
|
|
|
// create multiple generate requests
|
|
let mut infer_responses: Vec<InferResponse> =
|
|
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
|
|
|
|
// get the sequence with the highest log probability per token
|
|
let mut max_index = 0;
|
|
let mut max_logprob: f32 = f32::MIN;
|
|
|
|
for (i, response) in infer_responses.iter().enumerate() {
|
|
// mean logprobs of the generated tokens
|
|
let sequence_logprob = response
|
|
.tokens
|
|
.iter()
|
|
.map(|token| token.logprob)
|
|
.sum::<f32>()
|
|
/ response.tokens.len() as f32;
|
|
|
|
// set best sequence
|
|
if sequence_logprob > max_logprob {
|
|
max_index = i;
|
|
max_logprob = sequence_logprob;
|
|
}
|
|
}
|
|
let best_response = infer_responses.remove(max_index);
|
|
Ok((best_response, infer_responses))
|
|
}
|
|
}
|
|
|
|
/// Batching logic
|
|
/// Will be launched in a background Tokio task
|
|
///
|
|
/// Batches requests and sends them to the inference server
|
|
async fn batching_task(
|
|
mut client: ShardedClient,
|
|
waiting_served_ratio: f32,
|
|
max_batch_total_tokens: u32,
|
|
max_waiting_tokens: usize,
|
|
queue: Queue,
|
|
shared: Arc<Shared>,
|
|
generation_health: Arc<AtomicBool>,
|
|
) {
|
|
// Infinite loop
|
|
loop {
|
|
// Wait for a notification from the Infer struct
|
|
shared.batching_task.notified().await;
|
|
|
|
// Get the next batch from the queue
|
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
|
// waiting in the queue
|
|
while let Some((mut entries, batch, span)) =
|
|
queue.next_batch(None, max_batch_total_tokens).await
|
|
{
|
|
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
|
.instrument(span)
|
|
.await;
|
|
let mut waiting_tokens = 1;
|
|
|
|
// We loop until we do not receive any cached batch from the inference server (== until
|
|
// all requests have met their stopping criteria)
|
|
while let Some(batch) = cached_batch {
|
|
// Get current batch info
|
|
let batch_size = batch.size;
|
|
let batch_max_tokens = batch.max_tokens;
|
|
let mut batches = vec![batch];
|
|
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
|
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
|
|
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
|
// to add a new batch even though its size might be small
|
|
None
|
|
} else {
|
|
// Minimum batch size
|
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
|
};
|
|
|
|
let token_budget = max_batch_total_tokens - batch_max_tokens;
|
|
|
|
// Try to get a new batch
|
|
if let Some((mut new_entries, new_batch, span)) =
|
|
queue.next_batch(min_size, token_budget).await
|
|
{
|
|
// Tracking metrics
|
|
if min_size.is_some() {
|
|
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
|
} else {
|
|
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
|
}
|
|
|
|
entries.iter_mut().for_each(|(_, entry)| {
|
|
// Create a new span to add the info that this entry is waiting
|
|
// because a new batch is being computed
|
|
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
|
// Add relationships
|
|
span.follows_from(&entry_waiting_span);
|
|
entry_waiting_span.follows_from(&span);
|
|
// Update entry
|
|
entry.temp_span = Some(entry_waiting_span);
|
|
});
|
|
|
|
// Generate one token for this new batch to have the attention past in cache
|
|
let new_cached_batch =
|
|
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
|
.instrument(span)
|
|
.await;
|
|
// Reset waiting counter
|
|
waiting_tokens = 1;
|
|
// Extend current batch with the new batch
|
|
if let Some(new_cached_batch) = new_cached_batch {
|
|
entries.extend(new_entries);
|
|
batches.push(new_cached_batch);
|
|
}
|
|
}
|
|
|
|
// Create span for this batch to add context to inference calls
|
|
let next_batch_size = entries.len();
|
|
let next_batch_span =
|
|
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
|
entries.iter_mut().for_each(|(_, entry)| {
|
|
// Create a new span to link the batch back to this entry
|
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
|
// Add relationships
|
|
next_batch_span.follows_from(&entry_batch_span);
|
|
entry_batch_span.follows_from(&next_batch_span);
|
|
// Update entry
|
|
entry.temp_span = Some(entry_batch_span);
|
|
});
|
|
|
|
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
|
.instrument(next_batch_span)
|
|
.await;
|
|
waiting_tokens += 1;
|
|
}
|
|
metrics::gauge!("tgi_batch_current_size", 0.0);
|
|
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[instrument(skip_all)]
|
|
async fn prefill(
|
|
client: &mut ShardedClient,
|
|
batch: Batch,
|
|
entries: &mut IntMap<u64, Entry>,
|
|
generation_health: &Arc<AtomicBool>,
|
|
) -> Option<CachedBatch> {
|
|
let start_time = Instant::now();
|
|
let batch_id = batch.id;
|
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
|
|
|
match client.prefill(batch).await {
|
|
Ok((generations, next_batch)) => {
|
|
// Update health
|
|
generation_health.store(true, Ordering::SeqCst);
|
|
// Send generated tokens and filter stopped entries
|
|
filter_send_generations(generations, entries);
|
|
|
|
// Filter next batch and remove requests that were stopped
|
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
|
|
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
|
next_batch
|
|
}
|
|
// If we have an error, we discard the whole batch
|
|
Err(err) => {
|
|
// Update health
|
|
generation_health.store(false, Ordering::SeqCst);
|
|
let _ = client.clear_cache(Some(batch_id)).await;
|
|
send_errors(err, entries);
|
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
#[instrument(skip_all)]
|
|
async fn decode(
|
|
client: &mut ShardedClient,
|
|
batches: Vec<CachedBatch>,
|
|
entries: &mut IntMap<u64, Entry>,
|
|
generation_health: &Arc<AtomicBool>,
|
|
) -> Option<CachedBatch> {
|
|
let start_time = Instant::now();
|
|
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
|
|
|
match client.decode(batches).await {
|
|
Ok((generations, next_batch)) => {
|
|
// Update health
|
|
generation_health.store(true, Ordering::SeqCst);
|
|
// Send generated tokens and filter stopped entries
|
|
filter_send_generations(generations, entries);
|
|
|
|
// Filter next batch and remove requests that were stopped
|
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
|
|
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
|
next_batch
|
|
}
|
|
// If we have an error, we discard the whole batch
|
|
Err(err) => {
|
|
generation_health.store(false, Ordering::SeqCst);
|
|
for id in batch_ids {
|
|
let _ = client.clear_cache(Some(id)).await;
|
|
}
|
|
send_errors(err, entries);
|
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Filter a `batch` and remove all requests not present in `entries`
|
|
#[instrument(skip_all)]
|
|
async fn filter_batch(
|
|
client: &mut ShardedClient,
|
|
next_batch: Option<CachedBatch>,
|
|
entries: &IntMap<u64, Entry>,
|
|
) -> Option<CachedBatch> {
|
|
let mut batch = next_batch?;
|
|
|
|
// No need to filter
|
|
if batch.size as usize == entries.len() {
|
|
return Some(batch);
|
|
}
|
|
|
|
let id = batch.id;
|
|
|
|
// Retain only requests that are still in entries
|
|
batch.request_ids.retain(|id| entries.contains_key(id));
|
|
|
|
if batch.request_ids.is_empty() {
|
|
// All requests have been filtered out
|
|
// Next batch is now empty
|
|
// Clear it from the Python shards cache
|
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
|
client.clear_cache(Some(id)).await.unwrap();
|
|
None
|
|
} else {
|
|
// Filter Python shard cache
|
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
|
client.filter_batch(id, batch.request_ids).await.unwrap()
|
|
}
|
|
}
|
|
|
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
|
/// and filter entries
|
|
#[instrument(skip_all)]
|
|
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
|
generations.into_iter().for_each(|generation| {
|
|
let id = generation.request_id;
|
|
// Get entry
|
|
// We can `expect` here as the request id should always be in the entries
|
|
let entry = entries
|
|
.get(&id)
|
|
.expect("ID not found in entries. This is a bug.");
|
|
|
|
// Create and enter a span to link this function back to the entry
|
|
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
|
// Send generation responses back to the infer task
|
|
// If the receive an error from the Flume channel, it means that the client dropped the
|
|
// request and we need to stop generating hence why we unwrap_or(true)
|
|
let stopped = send_responses(generation, entry).map_err(|err| {
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
|
err
|
|
}).unwrap_or(true);
|
|
if stopped {
|
|
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Send responses through the `entry` response channel
|
|
fn send_responses(
|
|
generation: Generation,
|
|
entry: &Entry,
|
|
) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
|
|
let mut stopped = false;
|
|
|
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
|
// Send message
|
|
entry
|
|
.response_tx
|
|
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
|
}
|
|
|
|
// Create last Token
|
|
let token = Token {
|
|
id: generation.token_id,
|
|
text: generation.token_text,
|
|
logprob: generation.token_logprob,
|
|
special: generation.token_is_special,
|
|
};
|
|
|
|
if let Some(generated_text) = generation.generated_text {
|
|
// Generation has ended
|
|
stopped = true;
|
|
// Send message
|
|
entry.response_tx.send(Ok(InferStreamResponse::End {
|
|
token,
|
|
generated_text,
|
|
queued: entry.queue_time,
|
|
start: entry.batch_time.unwrap(),
|
|
}))?;
|
|
} else {
|
|
// Send message
|
|
entry
|
|
.response_tx
|
|
.send(Ok(InferStreamResponse::Token(token)))?;
|
|
}
|
|
Ok(stopped)
|
|
}
|
|
|
|
/// Send errors to Infer for all `entries`
|
|
#[instrument(skip_all)]
|
|
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|
entries.drain().for_each(|(_, entry)| {
|
|
// Create and enter a span to link this function back to the entry
|
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
|
let err = InferError::GenerationError(error.to_string());
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
|
tracing::error!("{err}");
|
|
|
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
|
entry
|
|
.response_tx
|
|
.send(Err(err))
|
|
.unwrap_or(());
|
|
});
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub(crate) enum InferStreamResponse {
|
|
// Optional first message
|
|
Prefill(PrefillTokens),
|
|
// Intermediate messages
|
|
Token(Token),
|
|
// Last message
|
|
End {
|
|
token: Token,
|
|
generated_text: GeneratedText,
|
|
start: Instant,
|
|
queued: Instant,
|
|
},
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub(crate) struct InferResponse {
|
|
pub(crate) prefill: Vec<PrefillToken>,
|
|
pub(crate) tokens: Vec<Token>,
|
|
pub(crate) generated_text: GeneratedText,
|
|
pub(crate) queued: Instant,
|
|
pub(crate) start: Instant,
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum InferError {
|
|
#[error("Request failed during generation: {0}")]
|
|
GenerationError(String),
|
|
#[error("Model is overloaded")]
|
|
Overloaded(#[from] TryAcquireError),
|
|
#[error("Input validation error: {0}")]
|
|
ValidationError(#[from] ValidationError),
|
|
#[error("Incomplete generation")]
|
|
IncompleteGeneration,
|
|
}
|
|
|
|
impl InferError {
|
|
pub(crate) fn error_type(&self) -> &str {
|
|
match self {
|
|
InferError::GenerationError(_) => "generation",
|
|
InferError::Overloaded(_) => "overloaded",
|
|
InferError::ValidationError(_) => "validation",
|
|
InferError::IncompleteGeneration => "incomplete_generation",
|
|
}
|
|
}
|
|
}
|