feat(router): add max_total_tokens and empty_input validation (#68)

closes #65
OlivierDehaene, 2023-02-15 21:56:59 +01:00 (committed by GitHub)
parent 68455353f5
commit 5437d49beb
3 changed files with 74 additions and 17 deletions

router/src/main.rs

@@ -20,8 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer};
 struct Args {
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
     max_input_length: usize,
+    #[clap(default_value = "1512", long, env)]
+    max_total_tokens: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
     #[clap(default_value = "20", long, env)]
@@ -46,7 +50,9 @@ fn main() -> Result<(), std::io::Error> {
     // Pattern match configuration
     let Args {
         max_concurrent_requests,
+        max_stop_sequences,
         max_input_length,
+        max_total_tokens,
         max_batch_size,
         max_waiting_tokens,
         port,
@@ -92,7 +98,9 @@ fn main() -> Result<(), std::io::Error> {
     // Run server
     server::run(
         max_concurrent_requests,
+        max_stop_sequences,
         max_input_length,
+        max_total_tokens,
         max_batch_size,
         max_waiting_tokens,
         sharded_client,
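
For reference, clap's derive gives each of these fields both a kebab-case flag and an upper-snake-case environment variable, so the new limits can be set as --max-stop-sequences / MAX_STOP_SEQUENCES and --max-total-tokens / MAX_TOTAL_TOKENS. A minimal standalone sketch of that behavior, assuming clap 3.x with the env feature (which the router's Args struct already requires):

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        // Same attribute shape as the field above: `long` derives the
        // `--max-total-tokens` flag, `env` derives `MAX_TOTAL_TOKENS`.
        #[clap(default_value = "1512", long, env)]
        max_total_tokens: usize,
    }

    fn main() {
        // `--max-total-tokens 2048` and `MAX_TOTAL_TOKENS=2048` both override
        // the default; the command-line flag wins when both are set.
        let args = Args::parse();
        println!("max_total_tokens = {}", args.max_total_tokens);
    }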

router/src/server.rs

@@ -28,7 +28,7 @@ use utoipa_swagger_ui::SwaggerUi;
 async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
     //       be a bit too slow for a health check.
-    //       What we should do instead if check if the gRPC channels are still healthy.
+    //       What we should do instead is check if the gRPC channels are still healthy.

     // Send a small inference request
     infer
@@ -291,7 +291,9 @@ async fn generate_stream(
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
     max_concurrent_requests: usize,
+    max_stop_sequences: usize,
     max_input_length: usize,
+    max_total_tokens: usize,
     max_batch_size: usize,
     max_waiting_tokens: usize,
     client: ShardedClient,
@@ -333,7 +335,13 @@ pub async fn run(
 struct ApiDoc;

     // Create state
-    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
+    let validation = Validation::new(
+        validation_workers,
+        tokenizer,
+        max_stop_sequences,
+        max_input_length,
+        max_total_tokens,
+    );
     let infer = Infer::new(
         client,
         validation,

router/src/validation.rs

@@ -1,3 +1,4 @@
+use crate::validation::ValidationError::EmptyInput;
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
@@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};
 use tracing::{instrument, Span};

-const MAX_MAX_NEW_TOKENS: u32 = 512;
-const MAX_STOP_SEQUENCES: usize = 4;
-
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
@@ -19,7 +17,13 @@ pub struct Validation {
 }

 impl Validation {
-    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
+    pub(crate) fn new(
+        workers: usize,
+        tokenizer: Tokenizer,
+        max_stop_sequences: usize,
+        max_input_length: usize,
+        max_total_tokens: usize,
+    ) -> Self {
         // Create channel
         let (validation_sender, validation_receiver) = mpsc::channel(128);
@@ -27,7 +31,9 @@ impl Validation {
         tokio::spawn(validation_task(
             workers,
             tokenizer,
+            max_stop_sequences,
             max_input_length,
+            max_total_tokens,
             validation_receiver,
         ));
@@ -61,7 +67,9 @@ impl Validation {
 async fn validation_task(
     workers: usize,
     tokenizer: Tokenizer,
+    max_stop_sequences: usize,
     max_input_length: usize,
+    max_total_tokens: usize,
     mut receiver: mpsc::Receiver<ValidationRequest>,
 ) {
     let mut workers_senders = Vec::with_capacity(workers);
@@ -75,7 +83,13 @@ async fn validation_task(
         // Spawn worker
         tokio::task::spawn_blocking(move || {
-            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
+            validation_worker(
+                tokenizer_clone,
+                max_stop_sequences,
+                max_input_length,
+                max_total_tokens,
+                worker_receiver,
+            )
         });
     }
@@ -95,7 +109,9 @@
 /// the tokenizer
 fn validation_worker(
     tokenizer: Tokenizer,
+    max_stop_sequences: usize,
     max_input_length: usize,
+    max_total_tokens: usize,
     mut receiver: mpsc::Receiver<ValidationRequest>,
 ) {
     // Seed rng
@@ -106,7 +122,15 @@
         parent_span.in_scope(|| {
             response_tx
                 .send(
-                    validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
+                    validate(
+                        request,
+                        &tokenizer,
+                        max_stop_sequences,
+                        max_input_length,
+                        max_total_tokens,
+                        &mut rng,
+                    )
+                    .map_err(|err| {
                         tracing::error!("{err}");
                         err
                     }),
@@ -119,7 +143,9 @@
 fn validate(
     request: GenerateRequest,
     tokenizer: &Tokenizer,
+    max_stop_sequences: usize,
     max_input_length: usize,
+    max_total_tokens: usize,
     rng: &mut ThreadRng,
 ) -> Result<ValidGenerateRequest, ValidationError> {
     let GenerateParameters {
@@ -161,13 +187,13 @@
         }
     }?;

-    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
-        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
+    if max_new_tokens == 0 {
+        return Err(ValidationError::MaxNewTokens);
     }

-    if stop_sequences.len() > MAX_STOP_SEQUENCES {
+    if stop_sequences.len() > max_stop_sequences {
         return Err(ValidationError::StopSequence(
-            MAX_STOP_SEQUENCES,
+            max_stop_sequences,
             stop_sequences.len(),
         ));
     }
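
Note that the hard MAX_MAX_NEW_TOKENS cap is dropped here rather than made configurable: max_new_tokens is now bounded indirectly through the total-token budget enforced in the next hunk. A standalone sketch of the ceiling that budget implies (hypothetical helper, not code from this diff):

    /// Largest max_new_tokens a prompt of `input_length` tokens can request
    /// before the MaxTotalTokens check rejects it.
    fn effective_max_new_tokens(input_length: usize, max_total_tokens: usize) -> usize {
        max_total_tokens.saturating_sub(input_length)
    }

    fn main() {
        // With the new defaults (max_input_length = 1000, max_total_tokens = 1512),
        // a maximal prompt leaves room for 512 new tokens, the same ceiling the
        // removed constant enforced, while shorter prompts may request more.
        assert_eq!(effective_max_new_tokens(1000, 1512), 512);
        assert_eq!(effective_max_new_tokens(100, 1512), 1412);
    }
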
@@ -178,13 +204,24 @@
         Some(seed) => seed,
     };

+    // Check if inputs is empty
+    if request.inputs.is_empty() {
+        return Err(EmptyInput);
+    }
+
     // Get the number of tokens in the input
     match tokenizer.encode(request.inputs.clone(), true) {
         Ok(encoding) => {
             let input_length = encoding.len();
+            let total_tokens = input_length + max_new_tokens as usize;
             if input_length > max_input_length {
-                Err(ValidationError::InputLength(input_length, max_input_length))
+                Err(ValidationError::InputLength(max_input_length, input_length))
+            } else if total_tokens > max_total_tokens {
+                Err(ValidationError::MaxTotalTokens(
+                    max_total_tokens,
+                    input_length,
+                    max_new_tokens,
+                ))
             } else {
                 // Return ValidGenerateRequest
                 let parameters = NextTokenChooserParameters {
@@ -236,10 +273,14 @@ pub enum ValidationError {
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive and <= {0}")]
-    MaxNewTokens(u32),
-    #[error("inputs must have less than {1} tokens. Given: {0}")]
+    #[error("max_new_tokens must be strictly positive")]
+    MaxNewTokens,
+    #[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
+    MaxTotalTokens(usize, usize, u32),
+    #[error("inputs must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
+    #[error("inputs cannot be empty")]
+    EmptyInput,
     #[error("stop supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
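
Taken together, the new validation path can reject a request in three ways before it reaches a shard: empty input, an over-long prompt, and a prompt-plus-decode budget that exceeds max_total_tokens. An illustrative distillation of that order (hypothetical helper with error strings modeled on the variants above, not code from this diff):

    fn check_request(
        inputs: &str,
        input_length: usize, // token count produced by the tokenizer
        max_new_tokens: u32,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Result<(), String> {
        // 1. EmptyInput: rejected before tokenization.
        if inputs.is_empty() {
            return Err("inputs cannot be empty".into());
        }
        // 2. InputLength: the prompt alone is too long.
        if input_length > max_input_length {
            return Err(format!(
                "inputs must have less than {max_input_length} tokens. Given: {input_length}"
            ));
        }
        // 3. MaxTotalTokens: prompt plus requested decode budget is too large.
        if input_length + max_new_tokens as usize > max_total_tokens {
            return Err(format!(
                "input tokens + max_new_tokens must be <= {max_total_tokens}. \
                 Given: {input_length} input tokens and {max_new_tokens} max_new_tokens"
            ));
        }
        Ok(())
    }

    fn main() {
        // Rejected: empty prompt.
        assert!(check_request("", 0, 10, 1000, 1512).is_err());
        // Rejected: 900 input tokens + 700 new tokens = 1600 > 1512.
        assert!(check_request("prompt", 900, 700, 1000, 1512).is_err());
        // Accepted: 900 + 512 = 1412 <= 1512.
        assert!(check_request("prompt", 900, 512, 1000, 1512).is_ok());
    }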