feat(router): Add max_waiting_tokens
This commit is contained in:
parent
895a341d06
commit
c837893370
|
@ -28,8 +28,8 @@ struct Args {
|
|||
max_input_length: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_waiting_time: u64,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server", long, env)]
|
||||
|
@ -41,7 +41,7 @@ struct Args {
|
|||
}
|
||||
|
||||
fn main() -> ExitCode {
|
||||
tracing_subscriber::fmt::init();
|
||||
tracing_subscriber::fmt().compact().with_ansi(false).init();
|
||||
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
|
@ -51,7 +51,7 @@ fn main() -> ExitCode {
|
|||
max_concurrent_requests,
|
||||
max_input_length,
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
shard_uds_path,
|
||||
master_addr,
|
||||
|
@ -148,8 +148,8 @@ fn main() -> ExitCode {
|
|||
&max_input_length.to_string(),
|
||||
"--max-batch-size",
|
||||
&max_batch_size.to_string(),
|
||||
"--max-waiting-time",
|
||||
&max_waiting_time.to_string(),
|
||||
"--max-waiting-tokens",
|
||||
&max_waiting_tokens.to_string(),
|
||||
"--port",
|
||||
&port.to_string(),
|
||||
"--master-shard-uds-path",
|
||||
|
|
|
@ -5,7 +5,6 @@ use axum::http::StatusCode;
|
|||
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{oneshot, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
@ -30,7 +29,7 @@ impl Batcher {
|
|||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
max_batch_size: usize,
|
||||
max_waiting_time: Duration,
|
||||
max_waiting_tokens: usize,
|
||||
) -> Self {
|
||||
// Batcher shared state
|
||||
let db = Db::new();
|
||||
|
@ -41,7 +40,7 @@ impl Batcher {
|
|||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
max_waiting_tokens,
|
||||
client,
|
||||
db.clone(),
|
||||
shared.clone(),
|
||||
|
@ -55,7 +54,7 @@ impl Batcher {
|
|||
&self,
|
||||
input_length: usize,
|
||||
request: GenerateRequest,
|
||||
) -> Result<String, InferError> {
|
||||
) -> Result<InferResponse, InferError> {
|
||||
// One shot channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
|
||||
|
@ -65,6 +64,7 @@ impl Batcher {
|
|||
response_tx,
|
||||
input_length,
|
||||
time: Instant::now(),
|
||||
batch_time: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the database that needs
|
||||
|
@ -87,7 +87,7 @@ impl Batcher {
|
|||
#[instrument(skip(client, db, shared))]
|
||||
async fn batching_task(
|
||||
max_batch_size: usize,
|
||||
max_waiting_time: Duration,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
db: Db,
|
||||
shared: Arc<Shared>,
|
||||
|
@ -103,8 +103,10 @@ async fn batching_task(
|
|||
// Get the next batch from the DB
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the DB
|
||||
if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
|
||||
let mut waiting_tokens = 0;
|
||||
if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
|
||||
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
|
||||
waiting_tokens += 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
|
@ -116,10 +118,20 @@ async fn batching_task(
|
|||
|
||||
// If the current batch is too small, we try to add more requests to it
|
||||
if batch_size <= limit_min_batch_size {
|
||||
// Get the next batch from the DB that meet our minimum size criteria
|
||||
let min_size = match waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
_ if waiting_tokens >= max_waiting_tokens => None,
|
||||
// Minimum size criteria
|
||||
_ => Some(limit_min_batch_size as usize),
|
||||
};
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((new_request_ids, new_batch)) =
|
||||
db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
|
||||
db.next_batch(min_size, max_batch_size)
|
||||
{
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 0;
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
|
||||
|
@ -129,24 +141,11 @@ async fn batching_task(
|
|||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
// If we don't have enough requests to meet the minimum size criteria, we
|
||||
// try to get the next batch from the DB that have been waiting over
|
||||
// the max_waiting_time
|
||||
else if let Some((new_request_ids, new_batch)) =
|
||||
db.next_batch(None, max_batch_size, Some(max_waiting_time))
|
||||
{
|
||||
let new_cached_batch =
|
||||
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cached_batch =
|
||||
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
|||
let entry = db
|
||||
.remove(&output.request.unwrap().id)
|
||||
.expect("ID not found in db. This is a bug.");
|
||||
let response = InferResponse {
|
||||
output: output.output,
|
||||
queued: entry.time,
|
||||
start: entry.batch_time.unwrap(), // unwrap is always valid
|
||||
end: Instant::now(),
|
||||
};
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry.response_tx.send(Ok(output.output)).unwrap_or(());
|
||||
entry.response_tx.send(Ok(response)).unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct InferResponse {
|
||||
pub(crate) output: String,
|
||||
pub(crate) queued: Instant,
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) end: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InferError {
|
||||
#[error("Request failed during generation: {0}")]
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use crate::InferResponse;
|
||||
/// This code is massively inspired by Tokio mini-redis
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::Instant;
|
||||
|
||||
|
@ -14,11 +14,13 @@ pub(crate) struct Entry {
|
|||
/// Request
|
||||
pub request: GenerateRequest,
|
||||
/// Response sender to communicate between the Batcher and the batching_task
|
||||
pub response_tx: Sender<Result<String, ClientError>>,
|
||||
pub response_tx: Sender<Result<InferResponse, ClientError>>,
|
||||
/// Number of tokens in the input
|
||||
pub input_length: usize,
|
||||
/// Instant when this entry was created
|
||||
pub time: Instant,
|
||||
/// Instant when this entry was added to a batch
|
||||
pub batch_time: Option<Instant>,
|
||||
}
|
||||
|
||||
/// Request Database
|
||||
|
@ -51,11 +53,7 @@ struct State {
|
|||
|
||||
impl State {
|
||||
/// Get the next requests
|
||||
fn next_requests(
|
||||
&self,
|
||||
max_size: usize,
|
||||
min_waiting_time: Option<Duration>,
|
||||
) -> Option<(Vec<u64>, Vec<Request>)> {
|
||||
fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
|
||||
// Iterates for max_size over the BTreemap starting from next_batch_start_id
|
||||
let mut requests = Vec::new();
|
||||
let mut ids = Vec::new();
|
||||
|
@ -67,15 +65,6 @@ impl State {
|
|||
// Take max_size
|
||||
.take(max_size)
|
||||
{
|
||||
if let Some(min_waiting_time) = min_waiting_time {
|
||||
// Only take entries that waited for at least min_waiting_time
|
||||
if entry.time.elapsed() < min_waiting_time {
|
||||
// Since entries are ordered, we already know that all following entries won't
|
||||
// satisfy the condition
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: *id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
|
@ -134,19 +123,22 @@ impl Db {
|
|||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
min_waiting_time: Option<Duration>,
|
||||
) -> Option<(Vec<u64>, Batch)> {
|
||||
// Acquire lock
|
||||
let mut state = self.shared.state.lock();
|
||||
|
||||
// Get requests from the database
|
||||
if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
|
||||
if let Some((ids, requests)) = state.next_requests(max_size) {
|
||||
if let Some(min_size) = min_size {
|
||||
// If min_size is set, only return a batch if there are enough requests
|
||||
if requests.len() < min_size {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
ids.iter().for_each(|id| {
|
||||
// Set batch_time for each request
|
||||
state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
|
||||
});
|
||||
|
||||
// Batch size
|
||||
let size = requests.len();
|
||||
|
|
|
@ -4,7 +4,7 @@ mod db;
|
|||
pub mod server;
|
||||
mod validation;
|
||||
|
||||
use batcher::Batcher;
|
||||
use batcher::{Batcher, InferResponse};
|
||||
use db::{Db, Entry};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validation::Validation;
|
||||
|
@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
|
|||
pub(crate) struct GeneratedText {
|
||||
pub generated_text: String,
|
||||
}
|
||||
|
||||
pub(crate) type GenerateResponse = Vec<GeneratedText>;
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
use bloom_inference_client::ShardedClient;
|
||||
use clap::Parser;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
|
@ -16,8 +15,8 @@ struct Args {
|
|||
max_input_length: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_waiting_time: u64,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
|
||||
|
@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
|
|||
max_concurrent_requests,
|
||||
max_input_length,
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
validation_workers,
|
||||
} = args;
|
||||
|
||||
tracing_subscriber::fmt().compact().with_ansi(false).init();
|
||||
|
||||
if validation_workers == 1 {
|
||||
panic!("validation_workers must be > 0");
|
||||
}
|
||||
|
||||
let max_waiting_time = Duration::from_secs(max_waiting_time);
|
||||
|
||||
// Download and instantiate tokenizer
|
||||
// This will only be used to validate payloads
|
||||
//
|
||||
|
@ -61,8 +60,6 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// Instantiate sharded client from the master unix socket
|
||||
let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
|
@ -82,7 +79,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
max_concurrent_requests,
|
||||
max_input_length,
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
max_waiting_tokens,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
validation_workers,
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
use crate::{
|
||||
Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
|
||||
};
|
||||
use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::sync::Semaphore;
|
||||
|
@ -59,12 +57,21 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
|
|||
}
|
||||
|
||||
/// Generate method
|
||||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
#[instrument(
|
||||
skip(state),
|
||||
fields(
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
inference_time,
|
||||
time_per_token
|
||||
)
|
||||
)]
|
||||
async fn generate(
|
||||
state: Extension<ServerState>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
|
||||
let start = Instant::now();
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let start_time = Instant::now();
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||
(
|
||||
|
@ -84,19 +91,51 @@ async fn generate(
|
|||
.await?;
|
||||
|
||||
// Inference
|
||||
let generated_text = state.batcher.infer(input_length, validated_request).await?;
|
||||
let response = state.batcher.infer(input_length, validated_request).await?;
|
||||
|
||||
// Timings
|
||||
let total_time = start_time.elapsed();
|
||||
let validation_time = response.queued - start_time;
|
||||
let queue_time = response.start - response.queued;
|
||||
let inference_time = response.end - response.start;
|
||||
let time_per_token = inference_time / req.parameters.max_new_tokens;
|
||||
|
||||
// Headers
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"x-total-time",
|
||||
total_time.as_millis().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-validation-time",
|
||||
validation_time.as_millis().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-queue-time",
|
||||
queue_time.as_millis().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-inference-time",
|
||||
inference_time.as_millis().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-time-per-token",
|
||||
time_per_token.as_millis().to_string().parse().unwrap(),
|
||||
);
|
||||
|
||||
// Tracing metadata
|
||||
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
|
||||
tracing::Span::current().record(
|
||||
"time_per_token",
|
||||
format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
|
||||
);
|
||||
tracing::info!("response: {}", generated_text);
|
||||
tracing::Span::current().record("total_time", format!("{:?}", total_time));
|
||||
tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
|
||||
tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
|
||||
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
|
||||
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
|
||||
tracing::info!("Output: {}", response.output);
|
||||
|
||||
// Send response
|
||||
let response = vec![GeneratedText { generated_text }];
|
||||
Ok(Json(response))
|
||||
let response = vec![GeneratedText {
|
||||
generated_text: response.output,
|
||||
}];
|
||||
Ok((headers, Json(response)))
|
||||
}
|
||||
|
||||
/// Serving method
|
||||
|
@ -105,14 +144,14 @@ pub async fn run(
|
|||
max_concurrent_requests: usize,
|
||||
max_input_length: usize,
|
||||
max_batch_size: usize,
|
||||
max_waiting_time: Duration,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
tokenizer: Tokenizer,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
) {
|
||||
// Create state
|
||||
let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
|
||||
let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
|
||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||
let shared_state = ServerState {
|
||||
validation,
|
||||
|
|
|
@ -127,7 +127,10 @@ fn validation_worker(
|
|||
|
||||
if input_length > max_input_length {
|
||||
response_tx
|
||||
.send(Err(ValidationError::InputLength(input_length, max_input_length)))
|
||||
.send(Err(ValidationError::InputLength(
|
||||
input_length,
|
||||
max_input_length,
|
||||
)))
|
||||
.unwrap_or(());
|
||||
continue;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue