feat(router): Add max_waiting_tokens
parent 895a341d06
commit c837893370
@@ -28,8 +28,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]
@@ -41,7 +41,7 @@ struct Args {
 }
 
 fn main() -> ExitCode {
-    tracing_subscriber::fmt::init();
+    tracing_subscriber::fmt().compact().with_ansi(false).init();
 
     // Pattern match configuration
     let Args {
@@ -51,7 +51,7 @@ fn main() -> ExitCode {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         shard_uds_path,
         master_addr,
@@ -148,8 +148,8 @@ fn main() -> ExitCode {
         &max_input_length.to_string(),
         "--max-batch-size",
         &max_batch_size.to_string(),
-        "--max-waiting-time",
-        &max_waiting_time.to_string(),
+        "--max-waiting-tokens",
+        &max_waiting_tokens.to_string(),
         "--port",
         &port.to_string(),
         "--master-shard-uds-path",
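The launcher change above only renames and retunes the flag: the wall-clock budget --max-waiting-time (default 5 seconds) becomes a decode-step budget --max-waiting-tokens (default 20), forwarded verbatim to the router process. A minimal, self-contained sketch of how such a clap-derived flag resolves (illustrative only; it assumes clap with the derive and env features, as this launcher already uses, and the MAX_WAITING_TOKENS variable name is clap's default env-name derivation rather than something this commit sets explicitly):

use clap::Parser;

/// Stand-alone stand-in for the launcher's Args struct, reduced to the renamed flag.
#[derive(Parser, Debug)]
struct Args {
    /// Read from --max-waiting-tokens, then MAX_WAITING_TOKENS, then the default.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
}

fn main() {
    let args = Args::parse();
    println!("max_waiting_tokens = {}", args.max_waiting_tokens);
}

Run with neither the flag nor the environment variable set, this prints the new default of 20.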
@@ -5,7 +5,6 @@ use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
-use std::time::Duration;
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;
@@ -30,7 +29,7 @@ impl Batcher {
     pub(crate) fn new(
         client: ShardedClient,
         max_batch_size: usize,
-        max_waiting_time: Duration,
+        max_waiting_tokens: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
@@ -41,7 +40,7 @@ impl Batcher {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             max_batch_size,
-            max_waiting_time,
+            max_waiting_tokens,
             client,
             db.clone(),
             shared.clone(),
@@ -55,7 +54,7 @@ impl Batcher {
         &self,
         input_length: usize,
         request: GenerateRequest,
-    ) -> Result<String, InferError> {
+    ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
@@ -65,6 +64,7 @@ impl Batcher {
             response_tx,
             input_length,
             time: Instant::now(),
+            batch_time: None,
         });
 
         // Notify the background task that we have a new entry in the database that needs
@@ -87,7 +87,7 @@ impl Batcher {
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
@@ -103,8 +103,10 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+        let mut waiting_tokens = 0;
+        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+            waiting_tokens += 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -116,10 +118,20 @@ async fn batching_task(
 
                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
-                    // Get the next batch from the DB that meet our minimum size criteria
+                    let min_size = match waiting_tokens {
+                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                        // to add a new batch even though its size might be small
+                        _ if waiting_tokens >= max_waiting_tokens => None,
+                        // Minimum size criteria
+                        _ => Some(limit_min_batch_size as usize),
+                    };
+
+                    // Try to get a new batch
                     if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+                        db.next_batch(min_size, max_batch_size)
                     {
+                        // Reset waiting counter
+                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
@@ -129,24 +141,11 @@ async fn batching_task(
                             batches.push(new_cached_batch);
                         }
                     }
-                    // If we don't have enough requests to meet the minimum size criteria, we
-                    // try to get the next batch from the DB that have been waiting over
-                    // the max_waiting_time
-                    else if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(None, max_batch_size, Some(max_waiting_time))
-                    {
-                        let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
-                            batches.push(new_cached_batch);
-                        }
-                    }
                 }
 
                 cached_batch =
                     wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                waiting_tokens += 1;
             }
         }
     }
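The hunks above are the heart of the change: instead of admitting small follow-up batches after a wall-clock max_waiting_time, the task now counts decode iterations (waiting_tokens) since requests were last added and drops the minimum-size requirement once that counter reaches max_waiting_tokens. A self-contained sketch of that decision, distilled from the match expression above (the function name and signature are illustrative, not the router's API):

/// Minimum size a follow-up batch must reach before it is concatenated.
/// `waiting_tokens` counts decode iterations since requests were last added.
fn min_batch_size(
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    limit_min_batch_size: usize,
) -> Option<usize> {
    if waiting_tokens >= max_waiting_tokens {
        // Waited long enough: admit whatever is queued, however small.
        None
    } else {
        // Otherwise only pay the concatenation cost for a reasonably large batch.
        Some(limit_min_batch_size)
    }
}

fn main() {
    // With the new default of max_waiting_tokens = 20 and e.g. limit_min_batch_size = 16:
    assert_eq!(min_batch_size(3, 20, 16), Some(16)); // keep waiting for more requests
    assert_eq!(min_batch_size(20, 20, 16), None); // admit the stragglers
}

Counting iterations rather than seconds ties the waiting budget to actual generation progress, so the threshold behaves the same whether a decode step takes 10 ms or 500 ms.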
@@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+        let response = InferResponse {
+            output: output.output,
+            queued: entry.time,
+            start: entry.batch_time.unwrap(), // unwrap is always valid
+            end: Instant::now(),
+        };
         // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Ok(output.output)).unwrap_or(());
+        entry.response_tx.send(Ok(response)).unwrap_or(());
     });
 }
 
+#[derive(Debug)]
+pub(crate) struct InferResponse {
+    pub(crate) output: String,
+    pub(crate) queued: Instant,
+    pub(crate) start: Instant,
+    pub(crate) end: Instant,
+}
+
 #[derive(Debug, Error)]
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
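InferResponse is the new return type of Batcher::infer: besides the generated text it carries three instants (queued when the request entered the database, start when it joined a batch, end when generation finished) so the HTTP layer can compute per-request timings. A self-contained sketch of how the durations fall out of those instants (stand-in types; the router itself uses tokio::time::Instant, which subtracts the same way):

use std::time::{Duration, Instant};

/// Reduced stand-in for the router's InferResponse, keeping only the timing fields.
struct InferResponse {
    queued: Instant, // request entered the Db (Entry::time)
    start: Instant,  // request was added to a batch (Entry::batch_time)
    end: Instant,    // generation finished
}

/// Validation, queue, inference and per-token durations, as derived in `generate`.
fn timings(
    request_start: Instant,
    response: &InferResponse,
    max_new_tokens: u32,
) -> (Duration, Duration, Duration, Duration) {
    let validation_time = response.queued - request_start;
    let queue_time = response.start - response.queued;
    let inference_time = response.end - response.start;
    let time_per_token = inference_time / max_new_tokens;
    (validation_time, queue_time, inference_time, time_per_token)
}

fn main() {
    let request_start = Instant::now();
    let response = InferResponse {
        queued: Instant::now(),
        start: Instant::now(),
        end: Instant::now(),
    };
    println!("{:?}", timings(request_start, &response, 20));
}

In the router, Entry::batch_time is an Option filled when Db::next_batch selects the entry, which is why the unwrap in send_generated above is marked as always valid: a finished generation necessarily went through a batch.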
@@ -1,10 +1,10 @@
+use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use std::time::Duration;
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -14,11 +14,13 @@ pub(crate) struct Entry {
     /// Request
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<String, ClientError>>,
+    pub response_tx: Sender<Result<InferResponse, ClientError>>,
     /// Number of tokens in the input
    pub input_length: usize,
     /// Instant when this entry was created
     pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
 }
 
 /// Request Database
@@ -51,11 +53,7 @@ struct State {
 
 impl State {
     /// Get the next requests
-    fn next_requests(
-        &self,
-        max_size: usize,
-        min_waiting_time: Option<Duration>,
-    ) -> Option<(Vec<u64>, Vec<Request>)> {
+    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
         // Iterates for max_size over the BTreemap starting from next_batch_start_id
         let mut requests = Vec::new();
         let mut ids = Vec::new();
@@ -67,15 +65,6 @@ impl State {
             // Take max_size
             .take(max_size)
         {
-            if let Some(min_waiting_time) = min_waiting_time {
-                // Only take entries that waited for at least min_waiting_time
-                if entry.time.elapsed() < min_waiting_time {
-                    // Since entries are ordered, we already know that all following entries won't
-                    // satisfy the condition
-                    break;
-                }
-            }
-
             requests.push(Request {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
@@ -134,19 +123,22 @@ impl Db {
         &self,
         min_size: Option<usize>,
         max_size: usize,
-        min_waiting_time: Option<Duration>,
     ) -> Option<(Vec<u64>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
 
         // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+        if let Some((ids, requests)) = state.next_requests(max_size) {
             if let Some(min_size) = min_size {
                 // If min_size is set, only return a batch if there are enough requests
                 if requests.len() < min_size {
                     return None;
                 }
             }
+            ids.iter().for_each(|id| {
+                // Set batch_time for each request
+                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+            });
 
             // Batch size
             let size = requests.len();
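On the Db side, the time-based filter is gone entirely: next_batch now takes only an optional minimum size and a maximum size, and it stamps batch_time on every entry it selects. A small illustrative sketch of the resulting admission rule (hypothetical helper, not code from this commit):

/// Whether a candidate batch of `waiting_requests` pending requests may be formed.
/// Mirrors the min_size gate in Db::next_batch; returns the batch size to use.
fn admit_batch(waiting_requests: usize, min_size: Option<usize>, max_size: usize) -> Option<usize> {
    // Never take more than max_size requests.
    let candidate = waiting_requests.min(max_size);
    if candidate == 0 {
        return None;
    }
    // With a minimum size set, refuse to form a batch that is too small.
    if let Some(min_size) = min_size {
        if candidate < min_size {
            return None;
        }
    }
    Some(candidate)
}

fn main() {
    assert_eq!(admit_batch(3, Some(16), 32), None); // too small while a batch is running
    assert_eq!(admit_batch(3, None, 32), Some(3)); // admitted once the token budget is spent
    assert_eq!(admit_batch(40, Some(16), 32), Some(32)); // capped at max_size
}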
@@ -4,7 +4,7 @@ mod db;
 pub mod server;
 mod validation;
 
-use batcher::Batcher;
+use batcher::{Batcher, InferResponse};
 use db::{Db, Entry};
 use serde::{Deserialize, Serialize};
 use validation::Validation;
@@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct GeneratedText {
     pub generated_text: String,
 }
-
-pub(crate) type GenerateResponse = Vec<GeneratedText>;

@@ -2,7 +2,6 @@
 use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::time::Duration;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
 
@@ -16,8 +15,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
@@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         master_shard_uds_path,
         tokenizer_name,
         validation_workers,
     } = args;
 
+    tracing_subscriber::fmt().compact().with_ansi(false).init();
+
     if validation_workers == 1 {
         panic!("validation_workers must be > 0");
     }
 
-    let max_waiting_time = Duration::from_secs(max_waiting_time);
-
     // Download and instantiate tokenizer
     // This will only be used to validate payloads
     //
@@ -61,8 +60,6 @@ fn main() -> Result<(), std::io::Error> {
         .build()
         .unwrap()
         .block_on(async {
-            tracing_subscriber::fmt::init();
-
             // Instantiate sharded client from the master unix socket
             let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
@@ -82,7 +79,7 @@ fn main() -> Result<(), std::io::Error> {
                 max_concurrent_requests,
                 max_input_length,
                 max_batch_size,
-                max_waiting_time,
+                max_waiting_tokens,
                 sharded_client,
                 tokenizer,
                 validation_workers,

@@ -1,14 +1,12 @@
-use crate::{
-    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
-};
+use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
 use axum::extract::Extension;
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
+use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
-use std::time::Duration;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
@@ -59,12 +57,21 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
 }
 
 /// Generate method
-#[instrument(skip(state), fields(time, time_per_token))]
+#[instrument(
+    skip(state),
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token
+    )
+)]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
-    let start = Instant::now();
+) -> Result<impl IntoResponse, (StatusCode, String)> {
+    let start_time = Instant::now();
     // Limit concurrent requests by acquiring a permit from the semaphore
     let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
         (
@@ -84,19 +91,51 @@ async fn generate(
         .await?;
 
     // Inference
-    let generated_text = state.batcher.infer(input_length, validated_request).await?;
+    let response = state.batcher.infer(input_length, validated_request).await?;
 
+    // Timings
+    let total_time = start_time.elapsed();
+    let validation_time = response.queued - start_time;
+    let queue_time = response.start - response.queued;
+    let inference_time = response.end - response.start;
+    let time_per_token = inference_time / req.parameters.max_new_tokens;
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert(
+        "x-total-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-validation-time",
+        validation_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-time-per-token",
+        time_per_token.as_millis().to_string().parse().unwrap(),
+    );
+
     // Tracing metadata
-    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-    tracing::Span::current().record(
-        "time_per_token",
-        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-    );
-    tracing::info!("response: {}", generated_text);
+    tracing::Span::current().record("total_time", format!("{:?}", total_time));
+    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
+    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
+    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
+    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::info!("Output: {}", response.output);
 
     // Send response
-    let response = vec![GeneratedText { generated_text }];
-    Ok(Json(response))
+    let response = vec![GeneratedText {
+        generated_text: response.output,
+    }];
+    Ok((headers, Json(response)))
 }
 
 /// Serving method
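The generate handler now returns the headers alongside the JSON body, so callers can read the same timing breakdown that lands in the tracing span. A hypothetical client-side check of those headers (not part of this commit; it assumes the reqwest, tokio and serde_json crates, a router on the default port 3000, and a POST route at /generate with an inputs + parameters.max_new_tokens payload, which this diff does not itself show):

use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let body = serde_json::json!({
        "inputs": "Hello",
        "parameters": { "max_new_tokens": 20 }
    });
    let res = reqwest::Client::new()
        .post("http://localhost:3000/generate")
        .json(&body)
        .send()
        .await?;
    // Each header set by `generate` carries a duration in milliseconds.
    for name in [
        "x-total-time",
        "x-validation-time",
        "x-queue-time",
        "x-inference-time",
        "x-time-per-token",
    ] {
        if let Some(value) = res.headers().get(name) {
            println!("{name}: {} ms", value.to_str()?);
        }
    }
    Ok(())
}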
@@ -105,14 +144,14 @@ pub async fn run(
     max_concurrent_requests: usize,
     max_input_length: usize,
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     tokenizer: Tokenizer,
     validation_workers: usize,
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
+    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,

@@ -127,7 +127,10 @@ fn validation_worker(
 
         if input_length > max_input_length {
             response_tx
-                .send(Err(ValidationError::InputLength(input_length, max_input_length)))
+                .send(Err(ValidationError::InputLength(
+                    input_length,
+                    max_input_length,
+                )))
                 .unwrap_or(());
             continue;
         }