feat(router): add max_total_tokens and empty_input validation (#68)

closes #65
This commit is contained in:
OlivierDehaene 2023-02-15 21:56:59 +01:00 committed by GitHub
parent 68455353f5
commit 5437d49beb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 17 deletions

View File

@@ -20,8 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer};
struct Args { struct Args {
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1000", long, env)]
max_input_length: usize, max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)] #[clap(default_value = "32", long, env)]
max_batch_size: usize, max_batch_size: usize,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
@@ -46,7 +50,9 @@ fn main() -> Result<(), std::io::Error> {
// Pattern match configuration // Pattern match configuration
let Args { let Args {
max_concurrent_requests, max_concurrent_requests,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
max_batch_size, max_batch_size,
max_waiting_tokens, max_waiting_tokens,
port, port,
@@ -92,7 +98,9 @@ fn main() -> Result<(), std::io::Error> {
// Run server // Run server
server::run( server::run(
max_concurrent_requests, max_concurrent_requests,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
max_batch_size, max_batch_size,
max_waiting_tokens, max_waiting_tokens,
sharded_client, sharded_client,

View File

@@ -28,7 +28,7 @@ use utoipa_swagger_ui::SwaggerUi;
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check. // be a bit too slow for a health check.
// What we should do instead if check if the gRPC channels are still healthy. // What we should do instead is check if the gRPC channels are still healthy.
// Send a small inference request // Send a small inference request
infer infer
@@ -291,7 +291,9 @@ async fn generate_stream(
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize, max_batch_size: usize,
max_waiting_tokens: usize, max_waiting_tokens: usize,
client: ShardedClient, client: ShardedClient,
@@ -333,7 +335,13 @@ pub async fn run(
struct ApiDoc; struct ApiDoc;
// Create state // Create state
let validation = Validation::new(validation_workers, tokenizer, max_input_length); let validation = Validation::new(
validation_workers,
tokenizer,
max_stop_sequences,
max_input_length,
max_total_tokens,
);
let infer = Infer::new( let infer = Infer::new(
client, client,
validation, validation,

View File

@@ -1,3 +1,4 @@
use crate::validation::ValidationError::EmptyInput;
/// Payload validation logic /// Payload validation logic
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng; use rand::rngs::ThreadRng;
@@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span}; use tracing::{instrument, Span};
const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
/// Validation /// Validation
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Validation { pub struct Validation {
@@ -19,7 +17,13 @@ pub struct Validation {
} }
impl Validation { impl Validation {
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self { pub(crate) fn new(
workers: usize,
tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
// Create channel // Create channel
let (validation_sender, validation_receiver) = mpsc::channel(128); let (validation_sender, validation_receiver) = mpsc::channel(128);
@@ -27,7 +31,9 @@ impl Validation {
tokio::spawn(validation_task( tokio::spawn(validation_task(
workers, workers,
tokenizer, tokenizer,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
validation_receiver, validation_receiver,
)); ));
@@ -61,7 +67,9 @@ impl Validation {
async fn validation_task( async fn validation_task(
workers: usize, workers: usize,
tokenizer: Tokenizer, tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
mut receiver: mpsc::Receiver<ValidationRequest>, mut receiver: mpsc::Receiver<ValidationRequest>,
) { ) {
let mut workers_senders = Vec::with_capacity(workers); let mut workers_senders = Vec::with_capacity(workers);
@@ -75,7 +83,13 @@ async fn validation_task(
// Spawn worker // Spawn worker
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
validation_worker(tokenizer_clone, max_input_length, worker_receiver) validation_worker(
tokenizer_clone,
max_stop_sequences,
max_input_length,
max_total_tokens,
worker_receiver,
)
}); });
} }
@@ -95,7 +109,9 @@ async fn validation_task(
/// the tokenizer /// the tokenizer
fn validation_worker( fn validation_worker(
tokenizer: Tokenizer, tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
mut receiver: mpsc::Receiver<ValidationRequest>, mut receiver: mpsc::Receiver<ValidationRequest>,
) { ) {
// Seed rng // Seed rng
@@ -106,7 +122,15 @@ fn validation_worker(
parent_span.in_scope(|| { parent_span.in_scope(|| {
response_tx response_tx
.send( .send(
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| { validate(
request,
&tokenizer,
max_stop_sequences,
max_input_length,
max_total_tokens,
&mut rng,
)
.map_err(|err| {
tracing::error!("{err}"); tracing::error!("{err}");
err err
}), }),
@@ -119,7 +143,9 @@ fn validation_worker(
fn validate( fn validate(
request: GenerateRequest, request: GenerateRequest,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
rng: &mut ThreadRng, rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> { ) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters { let GenerateParameters {
@@ -161,13 +187,13 @@ fn validate(
} }
}?; }?;
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS { if max_new_tokens == 0 {
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS)); return Err(ValidationError::MaxNewTokens);
} }
if stop_sequences.len() > MAX_STOP_SEQUENCES { if stop_sequences.len() > max_stop_sequences {
return Err(ValidationError::StopSequence( return Err(ValidationError::StopSequence(
MAX_STOP_SEQUENCES, max_stop_sequences,
stop_sequences.len(), stop_sequences.len(),
)); ));
} }
@@ -178,13 +204,24 @@ fn validate(
Some(seed) => seed, Some(seed) => seed,
}; };
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
}
// Get the number of tokens in the input // Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) { match tokenizer.encode(request.inputs.clone(), true) {
Ok(encoding) => { Ok(encoding) => {
let input_length = encoding.len(); let input_length = encoding.len();
let total_tokens = input_length + max_new_tokens as usize;
if input_length > max_input_length { if input_length > max_input_length {
Err(ValidationError::InputLength(input_length, max_input_length)) Err(ValidationError::InputLength(max_input_length, input_length))
} else if total_tokens > max_total_tokens {
Err(ValidationError::MaxTotalTokens(
max_total_tokens,
input_length,
max_new_tokens,
))
} else { } else {
// Return ValidGenerateRequest // Return ValidGenerateRequest
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
@@ -236,10 +273,14 @@ pub enum ValidationError {
TopP, TopP,
#[error("top_k must be strictly positive")] #[error("top_k must be strictly positive")]
TopK, TopK,
#[error("max_new_tokens must be strictly positive and <= {0}")] #[error("max_new_tokens must be strictly positive")]
MaxNewTokens(u32), MaxNewTokens,
#[error("inputs must have less than {1} tokens. Given: {0}")] #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
MaxTotalTokens(usize, usize, u32),
#[error("inputs must have less than {0} tokens. Given: {1}")]
InputLength(usize, usize), InputLength(usize, usize),
#[error("inputs cannot be empty")]
EmptyInput,
#[error("stop supports up to {0} stop sequences. Given: {1}")] #[error("stop supports up to {0} stop sequences. Given: {1}")]
StopSequence(usize, usize), StopSequence(usize, usize),
#[error("tokenizer error {0}")] #[error("tokenizer error {0}")]