From 60472f9d2b03822e9e771d9a4a4885e2496c5537 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 3 Jan 2023 01:41:22 -0800 Subject: [PATCH] feat(router): Add const parameters to validation logic (#15) I noticed some opportunity to collapse some of the logic, in case you are interested. --- router/src/db.rs | 12 ++--- router/src/validation.rs | 109 +++++++++++++++++---------------------- 2 files changed, 51 insertions(+), 70 deletions(-) diff --git a/router/src/db.rs b/router/src/db.rs index df9f2b8e..1d7df627 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -71,12 +71,8 @@ impl State { id: *id, inputs: entry.request.inputs.clone(), input_length: entry.input_length as u32, - parameters: Some(NextTokenChooserParameters::from( - entry.request.parameters.clone(), - )), - stopping_parameters: Some(StoppingCriteriaParameters::from( - entry.request.parameters.clone(), - )), + parameters: Some((&entry.request.parameters).into()), + stopping_parameters: Some(entry.request.parameters.clone().into()), }); ids.push(*id); @@ -162,8 +158,8 @@ impl Db { } } -impl From for NextTokenChooserParameters { - fn from(parameters: GenerateParameters) -> Self { +impl From<&GenerateParameters> for NextTokenChooserParameters { + fn from(parameters: &GenerateParameters) -> Self { Self { temperature: parameters.temperature, top_k: parameters.top_k as u32, diff --git a/router/src/validation.rs b/router/src/validation.rs index f6da1913..4a9d0c23 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -4,12 +4,11 @@ use axum::http::StatusCode; use axum::Json; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; -use tokenizers::{ - DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper, - TokenizerImpl, -}; use tokio::sync::{mpsc, oneshot}; +const MAX_MAX_NEW_TOKENS: u32 = 512; +const MAX_STOP_SEQUENCES: usize = 4; + /// Validation #[derive(Debug, Clone)] pub struct Validation { @@ -63,7 +62,7 @@ async fn validation_task( // Create workers for _ in 0..workers { - let tokenizer_clone = tokenizer.clone(); + let tokenizer_clone: Tokenizer = tokenizer.clone().into(); // Create channel to communicate with worker let (worker_sender, worker_receiver) = mpsc::channel(workers); workers_senders.push(worker_sender); @@ -89,68 +88,54 @@ async fn validation_task( /// Check the parameters inside the payload and get the number of tokens inside the input using /// the tokenizer fn validation_worker( - tokenizer: TokenizerImpl< - ModelWrapper, - NormalizerWrapper, - PreTokenizerWrapper, - PostProcessorWrapper, - DecoderWrapper, - >, + tokenizer: Tokenizer, max_input_length: usize, mut receiver: mpsc::Receiver, ) { // Loop over requests while let Some((request, response_tx)) = receiver.blocking_recv() { - if request.parameters.temperature <= 0.0 { - response_tx - .send(Err(ValidationError::Temperature)) - .unwrap_or(()); - continue; - } - if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 { - response_tx.send(Err(ValidationError::TopP)).unwrap_or(()); - continue; - } - if request.parameters.top_k < 0 { - response_tx.send(Err(ValidationError::TopK)).unwrap_or(()); - continue; - } - if request.parameters.max_new_tokens > 512 { - response_tx - .send(Err(ValidationError::MaxNewTokens)) - .unwrap_or(()); - continue; - } - if request.parameters.stop.len() > 4 { - response_tx - .send(Err(ValidationError::StopSequence( - request.parameters.stop.len(), - ))) - .unwrap_or(()); - continue; - } + response_tx.send(validate(request, &tokenizer, max_input_length)).unwrap_or(()) + } +} - // Get the number of tokens in the input - match tokenizer.encode(request.inputs.clone(), true) { - Ok(inputs) => { - let input_length = inputs.len(); +fn validate( + request: GenerateRequest, + tokenizer: &Tokenizer, + max_input_length: usize, +) -> Result<(usize, GenerateRequest), ValidationError> { + if request.parameters.temperature <= 0.0 { + return Err(ValidationError::Temperature); + } + if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 { + return Err(ValidationError::TopP); + } + if request.parameters.top_k < 0 { + return Err(ValidationError::TopK); + } + if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS { + return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS)); + } + if request.parameters.stop.len() > MAX_STOP_SEQUENCES { + return Err(ValidationError::StopSequence( + MAX_STOP_SEQUENCES, request.parameters.stop.len(), + )) + } - if input_length > max_input_length { - response_tx - .send(Err(ValidationError::InputLength( - input_length, - max_input_length, - ))) - .unwrap_or(()); - continue; - } + // Get the number of tokens in the input + match tokenizer.encode(request.inputs.clone(), true) { + Ok(inputs) => { + let input_length = inputs.len(); - response_tx.send(Ok((input_length, request))).unwrap_or(()); + if input_length > max_input_length { + Err(ValidationError::InputLength( + input_length, + max_input_length, + )) + } else { + Ok((input_length, request)) } - Err(err) => response_tx - .send(Err(ValidationError::Tokenizer(err.to_string()))) - .unwrap_or(()), - }; + }, + Err(err) => Err(ValidationError::Tokenizer(err.to_string())), } } @@ -167,12 +152,12 @@ pub enum ValidationError { TopP, #[error("top_k must be strictly positive")] TopK, - #[error("max_new_tokens must be <= 512")] - MaxNewTokens, + #[error("max_new_tokens must be <= {0}")] + MaxNewTokens(u32), #[error("inputs must have less than {1} tokens. Given: {0}")] InputLength(usize, usize), - #[error("stop supports up to 4 stop sequences. Given: {0}")] - StopSequence(usize), + #[error("stop supports up to {0} stop sequences. Given: {1}")] + StopSequence(usize, usize), #[error("tokenizer error {0}")] Tokenizer(String), }