
//! Batching and inference logic.
use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
/// Inference struct
pub struct Infer {
/// Validation
validation: Validation,
/// Request queue
queue: Queue,
/// Shared state
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Infer shared state
struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify,
impl Infer {
pub(crate) fn new(
client: ShardedClient,
validation: Validation,
max_batch_size: usize,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
) -> Self {
// Infer shared state
let queue = Queue::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
// Spawn batching background task that contains all the inference logic
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
Self {
limit_concurrent_requests: semaphore,
/// Add a new request to the queue and return a stream of InferStreamResponse
pub(crate) async fn generate_stream(
request: GenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
// Validate request
let valid_request = self.validation.validate(request).await?;
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the queue
self.queue.append(Entry {
request: valid_request,
time: Instant::now(),
batch_time: None,
_permit: permit,
// Notify the background task that we have a new entry in the queue that needs
// to be batched
// Return stream
/// Add a new request to the queue and return a InferResponse
pub(crate) async fn generate(
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// Create stream
let mut stream = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
// Iterate on stream
while let Some(response) = {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.map(|((id, logprob), text)| Token { id, text, logprob })
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
// Final message
// Set return values
InferStreamResponse::End {
} => {
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
// Check that we received a `InferStreamResponse::End` message
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
Ok(InferResponse {
prefill: result_prefill,
tokens: result_tokens,
} else {
/// Batching logic
/// Will be launched in a background Tokio task
/// Batches requests and sends them to the inference server
#[instrument(skip(client, queue, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
max_waiting_tokens: usize,
queue: Queue,
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32;
// Infinite loop
loop {
// Wait for a notification from the Infer struct
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let mut batches = vec![batch];
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
let min_size = match waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
_ if waiting_tokens >= max_waiting_tokens => None,
// Minimum size criteria
_ => Some(limit_min_batch_size as usize),
// Try to get a new batch
if let Some((mut new_entries, new_batch)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries).await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
cached_batch = wrap_future(client.decode(batches), &mut entries).await;
waiting_tokens += 1;
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
match future.await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
// If we have an error, we discard the whole batch
Err(err) => {
send_error(err, entries);
/// Send errors to Infer for all `entries`
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// unwrap_or is valid here as we don't care if the receiver is gone.
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.expect("ID not found in entries. This is a bug.");
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message
// We can `expect` here as the request id should always be in the entries
let entry = entries
.expect("ID not found in entries. This is a bug.");
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
.send(Ok(InferStreamResponse::End {
queued: entry.time,
start: entry.batch_time.unwrap(),
} else {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
pub(crate) enum InferStreamResponse {
// Optional first message
// Intermediate messages
// Last message
End {
token: Token,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
pub(crate) struct InferResponse {
pub(crate) prefill: Vec<Token>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
#[error("Model is overloaded")]
Overloaded(#[from] TryAcquireError),
#[error("Input validation error: {0}")]
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]