feat(router): Add max_waiting_tokens

OlivierDehaene 2022-10-21 16:40:05 +02:00
parent 895a341d06
commit c837893370
7 changed files with 121 additions and 79 deletions

View File

@ -28,8 +28,8 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
@ -41,7 +41,7 @@ struct Args {
}
fn main() -> ExitCode {
-tracing_subscriber::fmt::init();
+tracing_subscriber::fmt().compact().with_ansi(false).init();
// Pattern match configuration
let Args {
@ -51,7 +51,7 @@ fn main() -> ExitCode {
max_concurrent_requests,
max_input_length,
max_batch_size,
-max_waiting_time,
+max_waiting_tokens,
port,
shard_uds_path,
master_addr,
@ -148,8 +148,8 @@ fn main() -> ExitCode {
&max_input_length.to_string(),
"--max-batch-size",
&max_batch_size.to_string(),
"--max-waiting-time",
&max_waiting_time.to_string(),
"--max-waiting-tokens",
&max_waiting_tokens.to_string(),
"--port",
&port.to_string(),
"--master-shard-uds-path",

View File

@ -5,7 +5,6 @@ use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
-use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
@ -30,7 +29,7 @@ impl Batcher {
pub(crate) fn new(
client: ShardedClient,
max_batch_size: usize,
-max_waiting_time: Duration,
+max_waiting_tokens: usize,
) -> Self {
// Batcher shared state
let db = Db::new();
@ -41,7 +40,7 @@ impl Batcher {
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
max_batch_size,
-max_waiting_time,
+max_waiting_tokens,
client,
db.clone(),
shared.clone(),
@ -55,7 +54,7 @@ impl Batcher {
&self,
input_length: usize,
request: GenerateRequest,
-) -> Result<String, InferError> {
+) -> Result<InferResponse, InferError> {
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();
@ -65,6 +64,7 @@ impl Batcher {
response_tx,
input_length,
time: Instant::now(),
+batch_time: None,
});
// Notify the background task that we have a new entry in the database that needs
@ -87,7 +87,7 @@ impl Batcher {
#[instrument(skip(client, db, shared))]
async fn batching_task(
max_batch_size: usize,
-max_waiting_time: Duration,
+max_waiting_tokens: usize,
client: ShardedClient,
db: Db,
shared: Arc<Shared>,
@ -103,8 +103,10 @@ async fn batching_task(
// Get the next batch from the DB
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
-if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+let mut waiting_tokens = 0;
+if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+waiting_tokens += 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
@ -116,10 +118,20 @@ async fn batching_task(
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
// Get the next batch from the DB that meet our minimum size criteria
+let min_size = match waiting_tokens {
+// If we didn't onboard any new requests since >= max_waiting_tokens, we try
+// to add a new batch even though its size might be small
+_ if waiting_tokens >= max_waiting_tokens => None,
+// Minimum size criteria
+_ => Some(limit_min_batch_size as usize),
+};
// Try to get a new batch
if let Some((new_request_ids, new_batch)) =
-db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+db.next_batch(min_size, max_batch_size)
{
+// Reset waiting counter
+waiting_tokens = 0;
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
@ -129,24 +141,11 @@ async fn batching_task(
batches.push(new_cached_batch);
}
}
-// If we don't have enough requests to meet the minimum size criteria, we
-// try to get the next batch from the DB that have been waiting over
-// the max_waiting_time
-else if let Some((new_request_ids, new_batch)) =
-db.next_batch(None, max_batch_size, Some(max_waiting_time))
-{
-let new_cached_batch =
-wrap_future(client.generate(new_batch), new_request_ids, &db).await;
-// Extend current batch with the new batch
-if let Some(new_cached_batch) = new_cached_batch {
-request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
-batches.push(new_cached_batch);
-}
-}
}
cached_batch =
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+waiting_tokens += 1;
}
}
}
@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
+let response = InferResponse {
+output: output.output,
+queued: entry.time,
+start: entry.batch_time.unwrap(), // unwrap is always valid
+end: Instant::now(),
+};
// unwrap_or is valid here as we don't care if the receiver is gone.
-entry.response_tx.send(Ok(output.output)).unwrap_or(());
+entry.response_tx.send(Ok(response)).unwrap_or(());
});
}
+#[derive(Debug)]
+pub(crate) struct InferResponse {
+pub(crate) output: String,
+pub(crate) queued: Instant,
+pub(crate) start: Instant,
+pub(crate) end: Instant,
+}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]

View File

@ -1,10 +1,10 @@
+use crate::InferResponse;
/// This code is massively inspired by Tokio mini-redis
use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
-use std::time::Duration;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
@ -14,11 +14,13 @@ pub(crate) struct Entry {
/// Request
pub request: GenerateRequest,
/// Response sender to communicate between the Batcher and the batching_task
-pub response_tx: Sender<Result<String, ClientError>>,
+pub response_tx: Sender<Result<InferResponse, ClientError>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created
pub time: Instant,
+/// Instant when this entry was added to a batch
+pub batch_time: Option<Instant>,
}
/// Request Database
@ -51,11 +53,7 @@ struct State {
impl State {
/// Get the next requests
-fn next_requests(
-&self,
-max_size: usize,
-min_waiting_time: Option<Duration>,
-) -> Option<(Vec<u64>, Vec<Request>)> {
+fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
// Iterates for max_size over the BTreemap starting from next_batch_start_id
let mut requests = Vec::new();
let mut ids = Vec::new();
@ -67,15 +65,6 @@ impl State {
// Take max_size
.take(max_size)
{
-if let Some(min_waiting_time) = min_waiting_time {
-// Only take entries that waited for at least min_waiting_time
-if entry.time.elapsed() < min_waiting_time {
-// Since entries are ordered, we already know that all following entries won't
-// satisfy the condition
-break;
-}
-}
requests.push(Request {
id: *id,
inputs: entry.request.inputs.clone(),
@ -134,19 +123,22 @@ impl Db {
&self,
min_size: Option<usize>,
max_size: usize,
-min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Batch)> {
// Acquire lock
let mut state = self.shared.state.lock();
// Get requests from the database
-if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+if let Some((ids, requests)) = state.next_requests(max_size) {
if let Some(min_size) = min_size {
// If min_size is set, only return a batch if there are enough requests
if requests.len() < min_size {
return None;
}
}
+ids.iter().for_each(|id| {
+// Set batch_time for each request
+state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+});
// Batch size
let size = requests.len();
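With the time-based filter removed, `Db::next_batch` only takes an optional minimum size, and it stamps `batch_time` on every request it schedules so queue and inference durations can be reported later. A trimmed-down sketch of that behaviour (the `Entry` and `Db` shapes below are simplified stand-ins, not the real types):

```rust
use std::time::Instant;

/// Simplified stand-in for a queued request (illustrative only).
struct Entry {
    batch_time: Option<Instant>, // set once the request is scheduled into a batch
}

struct Db {
    entries: Vec<Entry>,
}

impl Db {
    /// Return the indices of up to `max_size` unscheduled entries, but only if
    /// at least `min_size` of them (when given) are available. The old
    /// `min_waiting_time` filter is gone; fairness is handled by the caller
    /// through `waiting_tokens`.
    fn next_batch(&mut self, min_size: Option<usize>, max_size: usize) -> Option<Vec<usize>> {
        let ids: Vec<usize> = self
            .entries
            .iter()
            .enumerate()
            .filter(|(_, e)| e.batch_time.is_none())
            .map(|(i, _)| i)
            .take(max_size)
            .collect();
        if ids.is_empty() || min_size.map_or(false, |m| ids.len() < m) {
            return None;
        }
        // Stamp batch_time so queue time and inference time can be measured later.
        let now = Instant::now();
        for &i in &ids {
            self.entries[i].batch_time = Some(now);
        }
        Some(ids)
    }
}

fn main() {
    let mut db = Db {
        entries: (0..3).map(|_| Entry { batch_time: None }).collect(),
    };
    assert_eq!(db.next_batch(Some(8), 32), None); // not enough waiting requests
    assert_eq!(db.next_batch(None, 32), Some(vec![0, 1, 2])); // take whatever is there
    println!("ok");
}
```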

View File

@ -4,7 +4,7 @@ mod db;
pub mod server;
mod validation;
-use batcher::Batcher;
+use batcher::{Batcher, InferResponse};
use db::{Db, Entry};
use serde::{Deserialize, Serialize};
use validation::Validation;
@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
pub(crate) struct GeneratedText {
pub generated_text: String,
}
-pub(crate) type GenerateResponse = Vec<GeneratedText>;

View File

@ -2,7 +2,6 @@
use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::time::Duration;
use text_generation_router::server;
use tokenizers::Tokenizer;
@ -16,8 +15,8 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
max_concurrent_requests,
max_input_length,
max_batch_size,
-max_waiting_time,
+max_waiting_tokens,
port,
master_shard_uds_path,
tokenizer_name,
validation_workers,
} = args;
+tracing_subscriber::fmt().compact().with_ansi(false).init();
if validation_workers == 1 {
panic!("validation_workers must be > 0");
}
-let max_waiting_time = Duration::from_secs(max_waiting_time);
// Download and instantiate tokenizer
// This will only be used to validate payloads
//
@ -61,8 +60,6 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
-tracing_subscriber::fmt::init();
// Instantiate sharded client from the master unix socket
let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
@ -82,7 +79,7 @@ fn main() -> Result<(), std::io::Error> {
max_concurrent_requests,
max_input_length,
max_batch_size,
-max_waiting_time,
+max_waiting_tokens,
sharded_client,
tokenizer,
validation_workers,

View File

@ -1,14 +1,12 @@
-use crate::{
-Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
-};
+use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
use axum::extract::Extension;
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
+use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::sync::Arc;
-use std::time::Duration;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;
@ -59,12 +57,21 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
}
/// Generate method
-#[instrument(skip(state), fields(time, time_per_token))]
+#[instrument(
+skip(state),
+fields(
+total_time,
+validation_time,
+queue_time,
+inference_time,
+time_per_token
+)
+)]
async fn generate(
state: Extension<ServerState>,
req: Json<GenerateRequest>,
-) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
-let start = Instant::now();
+) -> Result<impl IntoResponse, (StatusCode, String)> {
+let start_time = Instant::now();
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
(
@ -84,19 +91,51 @@ async fn generate(
.await?;
// Inference
-let generated_text = state.batcher.infer(input_length, validated_request).await?;
+let response = state.batcher.infer(input_length, validated_request).await?;
+// Timings
+let total_time = start_time.elapsed();
+let validation_time = response.queued - start_time;
+let queue_time = response.start - response.queued;
+let inference_time = response.end - response.start;
+let time_per_token = inference_time / req.parameters.max_new_tokens;
+// Headers
+let mut headers = HeaderMap::new();
+headers.insert(
+"x-total-time",
+total_time.as_millis().to_string().parse().unwrap(),
+);
+headers.insert(
+"x-validation-time",
+validation_time.as_millis().to_string().parse().unwrap(),
+);
+headers.insert(
+"x-queue-time",
+queue_time.as_millis().to_string().parse().unwrap(),
+);
+headers.insert(
+"x-inference-time",
+inference_time.as_millis().to_string().parse().unwrap(),
+);
+headers.insert(
+"x-time-per-token",
+time_per_token.as_millis().to_string().parse().unwrap(),
+);
// Tracing metadata
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
tracing::Span::current().record(
"time_per_token",
format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
);
tracing::info!("response: {}", generated_text);
tracing::Span::current().record("total_time", format!("{:?}", total_time));
tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
tracing::info!("Output: {}", response.output);
// Send response
-let response = vec![GeneratedText { generated_text }];
-Ok(Json(response))
+let response = vec![GeneratedText {
+generated_text: response.output,
+}];
+Ok((headers, Json(response)))
}
/// Serving method
@ -105,14 +144,14 @@ pub async fn run(
max_concurrent_requests: usize,
max_input_length: usize,
max_batch_size: usize,
-max_waiting_time: Duration,
+max_waiting_tokens: usize,
client: ShardedClient,
tokenizer: Tokenizer,
validation_workers: usize,
addr: SocketAddr,
) {
// Create state
-let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
+let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
let shared_state = ServerState {
validation,
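On the HTTP side, `generate` now derives four durations from the instants carried by `InferResponse` (`queued`, `start`, `end`) plus its own `start_time`, records them on the tracing span, and exposes them as response headers. A rough sketch of that arithmetic, using `std::time::Instant` in place of `tokio::time::Instant` and a plain `Vec` in place of axum's `HeaderMap` so the example stays self-contained:

```rust
use std::time::Instant;

/// Mirrors the instants carried by the new InferResponse (illustrative stand-in).
struct InferResponse {
    queued: Instant, // request validated and pushed into the queue
    start: Instant,  // first forward pass that included this request
    end: Instant,    // last token generated
}

/// Compute the timing headers the way `generate` does; returned as
/// (name, milliseconds) pairs instead of a real HeaderMap.
fn timing_headers(
    start_time: Instant,
    response: &InferResponse,
    max_new_tokens: u32,
) -> Vec<(&'static str, u128)> {
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = response.end - response.start;
    let time_per_token = inference_time / max_new_tokens;
    vec![
        ("x-total-time", total_time.as_millis()),
        ("x-validation-time", validation_time.as_millis()),
        ("x-queue-time", queue_time.as_millis()),
        ("x-inference-time", inference_time.as_millis()),
        ("x-time-per-token", time_per_token.as_millis()),
    ]
}

fn main() {
    let start_time = Instant::now();
    let response = InferResponse {
        queued: Instant::now(),
        start: Instant::now(),
        end: Instant::now(),
    };
    for (name, millis) in timing_headers(start_time, &response, 20) {
        println!("{name}: {millis}ms");
    }
}
```

As in the diff, `x-time-per-token` divides by the requested `max_new_tokens`, not by the number of tokens actually generated.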

View File

@ -127,7 +127,10 @@ fn validation_worker(
if input_length > max_input_length {
response_tx
-.send(Err(ValidationError::InputLength(input_length, max_input_length)))
+.send(Err(ValidationError::InputLength(
+input_length,
+max_input_length,
+)))
.unwrap_or(());
continue;
}