feat(router): add max_total_tokens and empty_input validation (#68)
closes #65
This commit is contained in:
parent
68455353f5
commit
5437d49beb
|
@ -20,8 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer};
|
|||
struct Args {
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "1000", long, env)]
|
||||
max_input_length: usize,
|
||||
#[clap(default_value = "1512", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
|
@ -46,7 +50,9 @@ fn main() -> Result<(), std::io::Error> {
|
|||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_waiting_tokens,
|
||||
port,
|
||||
|
@ -92,7 +98,9 @@ fn main() -> Result<(), std::io::Error> {
|
|||
// Run server
|
||||
server::run(
|
||||
max_concurrent_requests,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_waiting_tokens,
|
||||
sharded_client,
|
||||
|
|
|
@ -28,7 +28,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
|||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||
// be a bit too slow for a health check.
|
||||
// What we should do instead if check if the gRPC channels are still healthy.
|
||||
// What we should do instead is check if the gRPC channels are still healthy.
|
||||
|
||||
// Send a small inference request
|
||||
infer
|
||||
|
@ -291,7 +291,9 @@ async fn generate_stream(
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
max_concurrent_requests: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: usize,
|
||||
max_waiting_tokens: usize,
|
||||
client: ShardedClient,
|
||||
|
@ -333,7 +335,13 @@ pub async fn run(
|
|||
struct ApiDoc;
|
||||
|
||||
// Create state
|
||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
let infer = Infer::new(
|
||||
client,
|
||||
validation,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::validation::ValidationError::EmptyInput;
|
||||
/// Payload validation logic
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::rngs::ThreadRng;
|
||||
|
@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer;
|
|||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{instrument, Span};
|
||||
|
||||
const MAX_MAX_NEW_TOKENS: u32 = 512;
|
||||
const MAX_STOP_SEQUENCES: usize = 4;
|
||||
|
||||
/// Validation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Validation {
|
||||
|
@ -19,7 +17,13 @@ pub struct Validation {
|
|||
}
|
||||
|
||||
impl Validation {
|
||||
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
|
||||
pub(crate) fn new(
|
||||
workers: usize,
|
||||
tokenizer: Tokenizer,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
||||
|
||||
|
@ -27,7 +31,9 @@ impl Validation {
|
|||
tokio::spawn(validation_task(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
validation_receiver,
|
||||
));
|
||||
|
||||
|
@ -61,7 +67,9 @@ impl Validation {
|
|||
async fn validation_task(
|
||||
workers: usize,
|
||||
tokenizer: Tokenizer,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
let mut workers_senders = Vec::with_capacity(workers);
|
||||
|
@ -75,7 +83,13 @@ async fn validation_task(
|
|||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
|
||||
validation_worker(
|
||||
tokenizer_clone,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
worker_receiver,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -95,7 +109,9 @@ async fn validation_task(
|
|||
/// the tokenizer
|
||||
fn validation_worker(
|
||||
tokenizer: Tokenizer,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
// Seed rng
|
||||
|
@ -106,7 +122,15 @@ fn validation_worker(
|
|||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(
|
||||
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
|
||||
validate(
|
||||
request,
|
||||
&tokenizer,
|
||||
max_stop_sequences,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
&mut rng,
|
||||
)
|
||||
.map_err(|err| {
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}),
|
||||
|
@ -119,7 +143,9 @@ fn validation_worker(
|
|||
fn validate(
|
||||
request: GenerateRequest,
|
||||
tokenizer: &Tokenizer,
|
||||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
rng: &mut ThreadRng,
|
||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||
let GenerateParameters {
|
||||
|
@ -161,13 +187,13 @@ fn validate(
|
|||
}
|
||||
}?;
|
||||
|
||||
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
|
||||
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
|
||||
if max_new_tokens == 0 {
|
||||
return Err(ValidationError::MaxNewTokens);
|
||||
}
|
||||
|
||||
if stop_sequences.len() > MAX_STOP_SEQUENCES {
|
||||
if stop_sequences.len() > max_stop_sequences {
|
||||
return Err(ValidationError::StopSequence(
|
||||
MAX_STOP_SEQUENCES,
|
||||
max_stop_sequences,
|
||||
stop_sequences.len(),
|
||||
));
|
||||
}
|
||||
|
@ -178,13 +204,24 @@ fn validate(
|
|||
Some(seed) => seed,
|
||||
};
|
||||
|
||||
// Check if inputs is empty
|
||||
if request.inputs.is_empty() {
|
||||
return Err(EmptyInput);
|
||||
}
|
||||
|
||||
// Get the number of tokens in the input
|
||||
match tokenizer.encode(request.inputs.clone(), true) {
|
||||
Ok(encoding) => {
|
||||
let input_length = encoding.len();
|
||||
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
if input_length > max_input_length {
|
||||
Err(ValidationError::InputLength(input_length, max_input_length))
|
||||
Err(ValidationError::InputLength(max_input_length, input_length))
|
||||
} else if total_tokens > max_total_tokens {
|
||||
Err(ValidationError::MaxTotalTokens(
|
||||
max_total_tokens,
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
))
|
||||
} else {
|
||||
// Return ValidGenerateRequest
|
||||
let parameters = NextTokenChooserParameters {
|
||||
|
@ -236,10 +273,14 @@ pub enum ValidationError {
|
|||
TopP,
|
||||
#[error("top_k must be strictly positive")]
|
||||
TopK,
|
||||
#[error("max_new_tokens must be strictly positive and <= {0}")]
|
||||
MaxNewTokens(u32),
|
||||
#[error("inputs must have less than {1} tokens. Given: {0}")]
|
||||
#[error("max_new_tokens must be strictly positive")]
|
||||
MaxNewTokens,
|
||||
#[error("input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens")]
|
||||
MaxTotalTokens(usize, usize, u32),
|
||||
#[error("inputs must have less than {0} tokens. Given: {1}")]
|
||||
InputLength(usize, usize),
|
||||
#[error("inputs cannot be empty")]
|
||||
EmptyInput,
|
||||
#[error("stop supports up to {0} stop sequences. Given: {1}")]
|
||||
StopSequence(usize, usize),
|
||||
#[error("tokenizer error {0}")]
|
||||
|
|
Loading…
Reference in New Issue