feat(router): Add max_waiting_tokens

Author: OlivierDehaene
Date:   2022-10-21 16:40:05 +02:00
Parent: 895a341d06
Commit: c837893370

7 changed files with 121 additions and 79 deletions
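
In short: the router no longer pulls in whatever requests have been waiting longer than a fixed max_waiting_time (previously a duration in seconds, default 5). Instead it counts forward passes since it last onboarded new requests (waiting_tokens) and, once that count reaches max_waiting_tokens (default 20), it accepts a new batch regardless of its size; otherwise a new batch is only merged into the running one if it meets the existing minimum-size threshold. The commit also switches tracing to compact non-ANSI output and has the generate handler report per-phase timings as response headers and span fields.

A minimal sketch of the new scheduling rule, pulled out of batching_task further down; the standalone function and its name are illustrative only, not part of the diff:

    // Illustrative helper (not in the commit): mirrors the `match waiting_tokens`
    // expression inside batching_task. `waiting_tokens` counts forward passes since
    // new requests were last onboarded; `limit_min_batch_size` is the existing
    // threshold below which the router tries to top the running batch up.
    fn min_size_for_new_batch(
        waiting_tokens: usize,
        max_waiting_tokens: usize,
        limit_min_batch_size: u32,
    ) -> Option<usize> {
        if waiting_tokens >= max_waiting_tokens {
            // Waited long enough: accept a new batch however small, to avoid starvation.
            None
        } else {
            // Otherwise only extend the running batch if enough requests are queued.
            Some(limit_min_batch_size as usize)
        }
    }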

Changed file 1 of 7 (launcher arguments and the spawned router command)

@@ -28,8 +28,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]
@@ -41,7 +41,7 @@ struct Args {
 }
 
 fn main() -> ExitCode {
-    tracing_subscriber::fmt::init();
+    tracing_subscriber::fmt().compact().with_ansi(false).init();
 
     // Pattern match configuration
     let Args {
@@ -51,7 +51,7 @@ fn main() -> ExitCode {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         shard_uds_path,
         master_addr,
@@ -148,8 +148,8 @@ fn main() -> ExitCode {
        &max_input_length.to_string(),
        "--max-batch-size",
        &max_batch_size.to_string(),
-       "--max-waiting-time",
-       &max_waiting_time.to_string(),
+       "--max-waiting-tokens",
+       &max_waiting_tokens.to_string(),
        "--port",
        &port.to_string(),
        "--master-shard-uds-path",

Changed file 2 of 7 (router batcher: batching task and the new InferResponse)

@@ -5,7 +5,6 @@ use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
-use std::time::Duration;
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;
@@ -30,7 +29,7 @@ impl Batcher {
     pub(crate) fn new(
         client: ShardedClient,
         max_batch_size: usize,
-        max_waiting_time: Duration,
+        max_waiting_tokens: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
@@ -41,7 +40,7 @@ impl Batcher {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             max_batch_size,
-            max_waiting_time,
+            max_waiting_tokens,
             client,
             db.clone(),
             shared.clone(),
@@ -55,7 +54,7 @@ impl Batcher {
         &self,
         input_length: usize,
         request: GenerateRequest,
-    ) -> Result<String, InferError> {
+    ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
@@ -65,6 +64,7 @@ impl Batcher {
             response_tx,
             input_length,
             time: Instant::now(),
+            batch_time: None,
         });
 
         // Notify the background task that we have a new entry in the database that needs
@@ -87,7 +87,7 @@ impl Batcher {
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
@@ -103,8 +103,10 @@ async fn batching_task(
         // Get the next batch from the DB
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the DB
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+        let mut waiting_tokens = 0;
+        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+            waiting_tokens += 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -116,10 +118,20 @@ async fn batching_task(
                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
-                    // Get the next batch from the DB that meet our minimum size criteria
+                    let min_size = match waiting_tokens {
+                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                        // to add a new batch even though its size might be small
+                        _ if waiting_tokens >= max_waiting_tokens => None,
+                        // Minimum size criteria
+                        _ => Some(limit_min_batch_size as usize),
+                    };
+
+                    // Try to get a new batch
                     if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+                        db.next_batch(min_size, max_batch_size)
                     {
+                        // Reset waiting counter
+                        waiting_tokens = 0;
                        // Generate one token for this new batch to have the attention past in cache
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
@@ -129,24 +141,11 @@ async fn batching_task(
                             batches.push(new_cached_batch);
                         }
                     }
-                    // If we don't have enough requests to meet the minimum size criteria, we
-                    // try to get the next batch from the DB that have been waiting over
-                    // the max_waiting_time
-                    else if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(None, max_batch_size, Some(max_waiting_time))
-                    {
-                        let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
-                            batches.push(new_cached_batch);
-                        }
-                    }
                 }
 
                 cached_batch =
                     wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                waiting_tokens += 1;
             }
         }
     }
@@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+        let response = InferResponse {
+            output: output.output,
+            queued: entry.time,
+            start: entry.batch_time.unwrap(), // unwrap is always valid
+            end: Instant::now(),
+        };
         // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Ok(output.output)).unwrap_or(());
+        entry.response_tx.send(Ok(response)).unwrap_or(());
     });
 }
 
+#[derive(Debug)]
+pub(crate) struct InferResponse {
+    pub(crate) output: String,
+    pub(crate) queued: Instant,
+    pub(crate) start: Instant,
+    pub(crate) end: Instant,
+}
+
 #[derive(Debug, Error)]
 pub enum InferError {
     #[error("Request failed during generation: {0}")]

Changed file 3 of 7 (router request database)

@@ -1,10 +1,10 @@
+use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use std::time::Duration;
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -14,11 +14,13 @@ pub(crate) struct Entry {
     /// Request
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<String, ClientError>>,
+    pub response_tx: Sender<Result<InferResponse, ClientError>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
     pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
 }
@@ -51,11 +53,7 @@ struct State {
 impl State {
     /// Get the next requests
-    fn next_requests(
-        &self,
-        max_size: usize,
-        min_waiting_time: Option<Duration>,
-    ) -> Option<(Vec<u64>, Vec<Request>)> {
+    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
         // Iterates for max_size over the BTreemap starting from next_batch_start_id
         let mut requests = Vec::new();
         let mut ids = Vec::new();
@@ -67,15 +65,6 @@ impl State {
            // Take max_size
            .take(max_size)
         {
-            if let Some(min_waiting_time) = min_waiting_time {
-                // Only take entries that waited for at least min_waiting_time
-                if entry.time.elapsed() < min_waiting_time {
-                    // Since entries are ordered, we already know that all following entries won't
-                    // satisfy the condition
-                    break;
-                }
-            }
-
             requests.push(Request {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
@@ -134,19 +123,22 @@ impl Db {
         &self,
         min_size: Option<usize>,
         max_size: usize,
-        min_waiting_time: Option<Duration>,
     ) -> Option<(Vec<u64>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
 
         // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+        if let Some((ids, requests)) = state.next_requests(max_size) {
             if let Some(min_size) = min_size {
                 // If min_size is set, only return a batch if there are enough requests
                 if requests.len() < min_size {
                     return None;
                 }
             }
+            ids.iter().for_each(|id| {
+                // Set batch_time for each request
+                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+            });
 
             // Batch size
             let size = requests.len();

Changed file 4 of 7 (router crate root and module exports)

@@ -4,7 +4,7 @@ mod db;
 pub mod server;
 mod validation;
 
-use batcher::Batcher;
+use batcher::{Batcher, InferResponse};
 use db::{Db, Entry};
 use serde::{Deserialize, Serialize};
 use validation::Validation;
@@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct GeneratedText {
     pub generated_text: String,
 }
-
-pub(crate) type GenerateResponse = Vec<GeneratedText>;

Changed file 5 of 7 (router main: CLI arguments and startup)

@@ -2,7 +2,6 @@
 use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::time::Duration;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
@@ -16,8 +15,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
@@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         master_shard_uds_path,
         tokenizer_name,
         validation_workers,
     } = args;
 
+    tracing_subscriber::fmt().compact().with_ansi(false).init();
+
     if validation_workers == 1 {
         panic!("validation_workers must be > 0");
     }
 
-    let max_waiting_time = Duration::from_secs(max_waiting_time);
-
     // Download and instantiate tokenizer
     // This will only be used to validate payloads
     //
@@ -61,8 +60,6 @@ fn main() -> Result<(), std::io::Error> {
         .build()
         .unwrap()
         .block_on(async {
-            tracing_subscriber::fmt::init();
-
             // Instantiate sharded client from the master unix socket
             let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
@@ -82,7 +79,7 @@ fn main() -> Result<(), std::io::Error> {
                 max_concurrent_requests,
                 max_input_length,
                 max_batch_size,
-                max_waiting_time,
+                max_waiting_tokens,
                 sharded_client,
                 tokenizer,
                 validation_workers,

Changed file 6 of 7 (HTTP server: the generate handler and run)

@@ -1,14 +1,12 @@
-use crate::{
-    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
-};
+use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
 use axum::extract::Extension;
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
+use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
-use std::time::Duration;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
@@ -59,12 +57,21 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
 }
 
 /// Generate method
-#[instrument(skip(state), fields(time, time_per_token))]
+#[instrument(
+    skip(state),
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token
+    )
+)]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
-    let start = Instant::now();
+) -> Result<impl IntoResponse, (StatusCode, String)> {
+    let start_time = Instant::now();
     // Limit concurrent requests by acquiring a permit from the semaphore
     let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
         (
@@ -84,19 +91,51 @@ async fn generate(
         .await?;
 
     // Inference
-    let generated_text = state.batcher.infer(input_length, validated_request).await?;
+    let response = state.batcher.infer(input_length, validated_request).await?;
+
+    // Timings
+    let total_time = start_time.elapsed();
+    let validation_time = response.queued - start_time;
+    let queue_time = response.start - response.queued;
+    let inference_time = response.end - response.start;
+    let time_per_token = inference_time / req.parameters.max_new_tokens;
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert(
+        "x-total-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-validation-time",
+        validation_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-time-per-token",
+        time_per_token.as_millis().to_string().parse().unwrap(),
+    );
 
     // Tracing metadata
-    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-    tracing::Span::current().record(
-        "time_per_token",
-        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-    );
-    tracing::info!("response: {}", generated_text);
+    tracing::Span::current().record("total_time", format!("{:?}", total_time));
+    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
+    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
+    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
+    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::info!("Output: {}", response.output);
 
     // Send response
-    let response = vec![GeneratedText { generated_text }];
-    Ok(Json(response))
+    let response = vec![GeneratedText {
+        generated_text: response.output,
+    }];
+    Ok((headers, Json(response)))
 }
@@ -105,14 +144,14 @@ pub async fn run(
     max_concurrent_requests: usize,
     max_input_length: usize,
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     tokenizer: Tokenizer,
     validation_workers: usize,
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
+    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
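
The generate handler above now attaches the per-phase timings as x-* response headers. A hedged client-side sketch of reading them with reqwest; the route, the request payload shape, and the port are assumptions not shown in this diff:

    // Assumed Cargo deps: reqwest = { version = "0.11", features = ["blocking", "json"] }, serde_json
    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let client = reqwest::blocking::Client::new();
        let resp = client
            .post("http://localhost:3000/generate") // route and port assumed
            .json(&serde_json::json!({
                "inputs": "Hello",
                "parameters": { "max_new_tokens": 20 } // field names assumed
            }))
            .send()?;
        // Each header value is a duration in milliseconds (see generate above)
        for name in [
            "x-total-time",
            "x-validation-time",
            "x-queue-time",
            "x-inference-time",
            "x-time-per-token",
        ] {
            if let Some(value) = resp.headers().get(name) {
                println!("{name}: {} ms", value.to_str()?);
            }
        }
        println!("body: {}", resp.text()?);
        Ok(())
    }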

Changed file 7 of 7 (request validation worker)

@@ -127,7 +127,10 @@ fn validation_worker(
            if input_length > max_input_length {
                response_tx
-                    .send(Err(ValidationError::InputLength(input_length, max_input_length)))
+                    .send(Err(ValidationError::InputLength(
+                        input_length,
+                        max_input_length,
+                    )))
                    .unwrap_or(());
                continue;
            }