From 5437d49bebe1338197a2630e9dc19b88664ad721 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 15 Feb 2023 21:56:59 +0100 Subject: [PATCH] feat(router): add max_total_tokens and empty_input validation (#68) closes #65 --- router/src/main.rs | 8 +++++ router/src/server.rs | 12 +++++-- router/src/validation.rs | 71 +++++++++++++++++++++++++++++++--------- 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 881d94fb..5ababa4b 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -20,8 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer}; struct Args { #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, #[clap(default_value = "1000", long, env)] max_input_length: usize, + #[clap(default_value = "1512", long, env)] + max_total_tokens: usize, #[clap(default_value = "32", long, env)] max_batch_size: usize, #[clap(default_value = "20", long, env)] @@ -46,7 +50,9 @@ fn main() -> Result<(), std::io::Error> { // Pattern match configuration let Args { max_concurrent_requests, + max_stop_sequences, max_input_length, + max_total_tokens, max_batch_size, max_waiting_tokens, port, @@ -92,7 +98,9 @@ fn main() -> Result<(), std::io::Error> { // Run server server::run( max_concurrent_requests, + max_stop_sequences, max_input_length, + max_total_tokens, max_batch_size, max_waiting_tokens, sharded_client, diff --git a/router/src/server.rs b/router/src/server.rs index 432586bb..19af1e78 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -28,7 +28,7 @@ use utoipa_swagger_ui::SwaggerUi; async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { // TODO: while this is the best health check we can do, it is a bit on the heavy side and might // be a bit too slow for a health check. - // What we should do instead if check if the gRPC channels are still healthy. + // What we should do instead is check if the gRPC channels are still healthy. // Send a small inference request infer @@ -291,7 +291,9 @@ async fn generate_stream( #[allow(clippy::too_many_arguments)] pub async fn run( max_concurrent_requests: usize, + max_stop_sequences: usize, max_input_length: usize, + max_total_tokens: usize, max_batch_size: usize, max_waiting_tokens: usize, client: ShardedClient, @@ -333,7 +335,13 @@ pub async fn run( struct ApiDoc; // Create state - let validation = Validation::new(validation_workers, tokenizer, max_input_length); + let validation = Validation::new( + validation_workers, + tokenizer, + max_stop_sequences, + max_input_length, + max_total_tokens, + ); let infer = Infer::new( client, validation, diff --git a/router/src/validation.rs b/router/src/validation.rs index aa1c1d23..50d090cd 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,3 +1,4 @@ +use crate::validation::ValidationError::EmptyInput; /// Payload validation logic use crate::{GenerateParameters, GenerateRequest}; use rand::rngs::ThreadRng; @@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; use tracing::{instrument, Span}; -const MAX_MAX_NEW_TOKENS: u32 = 512; -const MAX_STOP_SEQUENCES: usize = 4; - /// Validation #[derive(Debug, Clone)] pub struct Validation { @@ -19,7 +17,13 @@ pub struct Validation { } impl Validation { - pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self { + pub(crate) fn new( + workers: usize, + tokenizer: Tokenizer, + max_stop_sequences: usize, + max_input_length: usize, + max_total_tokens: usize, + ) -> Self { // Create channel let (validation_sender, validation_receiver) = mpsc::channel(128); @@ -27,7 +31,9 @@ impl Validation { tokio::spawn(validation_task( workers, tokenizer, + max_stop_sequences, max_input_length, + max_total_tokens, validation_receiver, )); @@ -61,7 +67,9 @@ impl Validation { async fn validation_task( workers: usize, tokenizer: Tokenizer, + max_stop_sequences: usize, max_input_length: usize, + max_total_tokens: usize, mut receiver: mpsc::Receiver, ) { let mut workers_senders = Vec::with_capacity(workers); @@ -75,7 +83,13 @@ async fn validation_task( // Spawn worker tokio::task::spawn_blocking(move || { - validation_worker(tokenizer_clone, max_input_length, worker_receiver) + validation_worker( + tokenizer_clone, + max_stop_sequences, + max_input_length, + max_total_tokens, + worker_receiver, + ) }); } @@ -95,7 +109,9 @@ async fn validation_task( /// the tokenizer fn validation_worker( tokenizer: Tokenizer, + max_stop_sequences: usize, max_input_length: usize, + max_total_tokens: usize, mut receiver: mpsc::Receiver, ) { // Seed rng @@ -106,7 +122,15 @@ fn validation_worker( parent_span.in_scope(|| { response_tx .send( - validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| { + validate( + request, + &tokenizer, + max_stop_sequences, + max_input_length, + max_total_tokens, + &mut rng, + ) + .map_err(|err| { tracing::error!("{err}"); err }), @@ -119,7 +143,9 @@ fn validation_worker( fn validate( request: GenerateRequest, tokenizer: &Tokenizer, + max_stop_sequences: usize, max_input_length: usize, + max_total_tokens: usize, rng: &mut ThreadRng, ) -> Result { let GenerateParameters { @@ -161,13 +187,13 @@ fn validate( } }?; - if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS { - return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS)); + if max_new_tokens == 0 { + return Err(ValidationError::MaxNewTokens); } - if stop_sequences.len() > MAX_STOP_SEQUENCES { + if stop_sequences.len() > max_stop_sequences { return Err(ValidationError::StopSequence( - MAX_STOP_SEQUENCES, + max_stop_sequences, stop_sequences.len(), )); } @@ -178,13 +204,24 @@ fn validate( Some(seed) => seed, }; + // Check if inputs is empty + if request.inputs.is_empty() { + return Err(EmptyInput); + } + // Get the number of tokens in the input match tokenizer.encode(request.inputs.clone(), true) { Ok(encoding) => { let input_length = encoding.len(); - + let total_tokens = input_length + max_new_tokens as usize; if input_length > max_input_length { - Err(ValidationError::InputLength(input_length, max_input_length)) + Err(ValidationError::InputLength(max_input_length, input_length)) + } else if total_tokens > max_total_tokens { + Err(ValidationError::MaxTotalTokens( + max_total_tokens, + input_length, + max_new_tokens, + )) } else { // Return ValidGenerateRequest let parameters = NextTokenChooserParameters { @@ -236,10 +273,14 @@ pub enum ValidationError { TopP, #[error("top_k must be strictly positive")] TopK, - #[error("max_new_tokens must be strictly positive and <= {0}")] - MaxNewTokens(u32), - #[error("inputs must have less than {1} tokens. Given: {0}")] + #[error("max_new_tokens must be strictly positive")] + MaxNewTokens, + #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")] + MaxTotalTokens(usize, usize, u32), + #[error("inputs must have less than {0} tokens. Given: {1}")] InputLength(usize, usize), + #[error("inputs cannot be empty")] + EmptyInput, #[error("stop supports up to {0} stop sequences. Given: {1}")] StopSequence(usize, usize), #[error("tokenizer error {0}")]