feat(router): add max_total_tokens and empty_input validation (#68)
closes #65
This commit is contained in:
parent
68455353f5
commit
5437d49beb
|
@ -20,8 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer};
|
||||||
struct Args {
|
struct Args {
|
||||||
#[clap(default_value = "128", long, env)]
|
#[clap(default_value = "128", long, env)]
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_stop_sequences: usize,
|
||||||
#[clap(default_value = "1000", long, env)]
|
#[clap(default_value = "1000", long, env)]
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
#[clap(default_value = "1512", long, env)]
|
||||||
|
max_total_tokens: usize,
|
||||||
#[clap(default_value = "32", long, env)]
|
#[clap(default_value = "32", long, env)]
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
|
@ -46,7 +50,9 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
// Pattern match configuration
|
// Pattern match configuration
|
||||||
let Args {
|
let Args {
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
port,
|
port,
|
||||||
|
@ -92,7 +98,9 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
// Run server
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
sharded_client,
|
sharded_client,
|
||||||
|
|
|
@ -28,7 +28,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
||||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
// be a bit too slow for a health check.
|
// be a bit too slow for a health check.
|
||||||
// What we should do instead if check if the gRPC channels are still healthy.
|
// What we should do instead is check if the gRPC channels are still healthy.
|
||||||
|
|
||||||
// Send a small inference request
|
// Send a small inference request
|
||||||
infer
|
infer
|
||||||
|
@ -291,7 +291,9 @@ async fn generate_stream(
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
|
@ -333,7 +335,13 @@ pub async fn run(
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
// Create state
|
// Create state
|
||||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
let validation = Validation::new(
|
||||||
|
validation_workers,
|
||||||
|
tokenizer,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
);
|
||||||
let infer = Infer::new(
|
let infer = Infer::new(
|
||||||
client,
|
client,
|
||||||
validation,
|
validation,
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::validation::ValidationError::EmptyInput;
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use rand::rngs::ThreadRng;
|
use rand::rngs::ThreadRng;
|
||||||
|
@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
|
|
||||||
const MAX_MAX_NEW_TOKENS: u32 = 512;
|
|
||||||
const MAX_STOP_SEQUENCES: usize = 4;
|
|
||||||
|
|
||||||
/// Validation
|
/// Validation
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Validation {
|
pub struct Validation {
|
||||||
|
@ -19,7 +17,13 @@ pub struct Validation {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Validation {
|
impl Validation {
|
||||||
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
|
pub(crate) fn new(
|
||||||
|
workers: usize,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
||||||
|
|
||||||
|
@ -27,7 +31,9 @@ impl Validation {
|
||||||
tokio::spawn(validation_task(
|
tokio::spawn(validation_task(
|
||||||
workers,
|
workers,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
validation_receiver,
|
validation_receiver,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -61,7 +67,9 @@ impl Validation {
|
||||||
async fn validation_task(
|
async fn validation_task(
|
||||||
workers: usize,
|
workers: usize,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||||
) {
|
) {
|
||||||
let mut workers_senders = Vec::with_capacity(workers);
|
let mut workers_senders = Vec::with_capacity(workers);
|
||||||
|
@ -75,7 +83,13 @@ async fn validation_task(
|
||||||
|
|
||||||
// Spawn worker
|
// Spawn worker
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
|
validation_worker(
|
||||||
|
tokenizer_clone,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
worker_receiver,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +109,9 @@ async fn validation_task(
|
||||||
/// the tokenizer
|
/// the tokenizer
|
||||||
fn validation_worker(
|
fn validation_worker(
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||||
) {
|
) {
|
||||||
// Seed rng
|
// Seed rng
|
||||||
|
@ -106,7 +122,15 @@ fn validation_worker(
|
||||||
parent_span.in_scope(|| {
|
parent_span.in_scope(|| {
|
||||||
response_tx
|
response_tx
|
||||||
.send(
|
.send(
|
||||||
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
|
validate(
|
||||||
|
request,
|
||||||
|
&tokenizer,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
&mut rng,
|
||||||
|
)
|
||||||
|
.map_err(|err| {
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
err
|
err
|
||||||
}),
|
}),
|
||||||
|
@ -119,7 +143,9 @@ fn validation_worker(
|
||||||
fn validate(
|
fn validate(
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
rng: &mut ThreadRng,
|
rng: &mut ThreadRng,
|
||||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||||
let GenerateParameters {
|
let GenerateParameters {
|
||||||
|
@ -161,13 +187,13 @@ fn validate(
|
||||||
}
|
}
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
|
if max_new_tokens == 0 {
|
||||||
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
|
return Err(ValidationError::MaxNewTokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
if stop_sequences.len() > MAX_STOP_SEQUENCES {
|
if stop_sequences.len() > max_stop_sequences {
|
||||||
return Err(ValidationError::StopSequence(
|
return Err(ValidationError::StopSequence(
|
||||||
MAX_STOP_SEQUENCES,
|
max_stop_sequences,
|
||||||
stop_sequences.len(),
|
stop_sequences.len(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
@ -178,13 +204,24 @@ fn validate(
|
||||||
Some(seed) => seed,
|
Some(seed) => seed,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Check if inputs is empty
|
||||||
|
if request.inputs.is_empty() {
|
||||||
|
return Err(EmptyInput);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the number of tokens in the input
|
// Get the number of tokens in the input
|
||||||
match tokenizer.encode(request.inputs.clone(), true) {
|
match tokenizer.encode(request.inputs.clone(), true) {
|
||||||
Ok(encoding) => {
|
Ok(encoding) => {
|
||||||
let input_length = encoding.len();
|
let input_length = encoding.len();
|
||||||
|
let total_tokens = input_length + max_new_tokens as usize;
|
||||||
if input_length > max_input_length {
|
if input_length > max_input_length {
|
||||||
Err(ValidationError::InputLength(input_length, max_input_length))
|
Err(ValidationError::InputLength(max_input_length, input_length))
|
||||||
|
} else if total_tokens > max_total_tokens {
|
||||||
|
Err(ValidationError::MaxTotalTokens(
|
||||||
|
max_total_tokens,
|
||||||
|
input_length,
|
||||||
|
max_new_tokens,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
// Return ValidGenerateRequest
|
// Return ValidGenerateRequest
|
||||||
let parameters = NextTokenChooserParameters {
|
let parameters = NextTokenChooserParameters {
|
||||||
|
@ -236,10 +273,14 @@ pub enum ValidationError {
|
||||||
TopP,
|
TopP,
|
||||||
#[error("top_k must be strictly positive")]
|
#[error("top_k must be strictly positive")]
|
||||||
TopK,
|
TopK,
|
||||||
#[error("max_new_tokens must be strictly positive and <= {0}")]
|
#[error("max_new_tokens must be strictly positive")]
|
||||||
MaxNewTokens(u32),
|
MaxNewTokens,
|
||||||
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
#[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
|
||||||
|
MaxTotalTokens(usize, usize, u32),
|
||||||
|
#[error("inputs must have less than {0} tokens. Given: {1}")]
|
||||||
InputLength(usize, usize),
|
InputLength(usize, usize),
|
||||||
|
#[error("inputs cannot be empty")]
|
||||||
|
EmptyInput,
|
||||||
#[error("stop supports up to {0} stop sequences. Given: {1}")]
|
#[error("stop supports up to {0} stop sequences. Given: {1}")]
|
||||||
StopSequence(usize, usize),
|
StopSequence(usize, usize),
|
||||||
#[error("tokenizer error {0}")]
|
#[error("tokenizer error {0}")]
|
||||||
|
|
Loading…
Reference in New Issue