feat(router): arg validation (#519)

This commit is contained in:
OlivierDehaene 2023-06-30 20:07:49 +02:00 committed by GitHub
parent e74bd41e0f
commit 3b0c979efc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 6 deletions

View File

@ -101,7 +101,7 @@ struct Args {
/// for users. The larger this value, the longer prompt users can send which /// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load. /// can impact the overall memory required to handle the load.
/// Please note that some models have a finite range of sequence they can handle. /// Please note that some models have a finite range of sequence they can handle.
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1024", long, env)]
max_input_length: usize, max_input_length: usize,
/// This is the most important value to set as it defines the "memory budget" /// This is the most important value to set as it defines the "memory budget"
@ -112,7 +112,7 @@ struct Args {
/// `1511` max_new_tokens. /// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM /// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be. /// and the less effective batching can be.
#[clap(default_value = "1512", long, env)] #[clap(default_value = "2048", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
/// This represents the ratio of waiting queries vs running queries where /// This represents the ratio of waiting queries vs running queries where

View File

@ -28,15 +28,15 @@ struct Args {
max_best_of: usize, max_best_of: usize,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_stop_sequences: usize, max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1024", long, env)]
max_input_length: usize, max_input_length: usize,
#[clap(default_value = "1512", long, env)] #[clap(default_value = "2048", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)] #[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)] #[clap(default_value = "16000", long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
max_waiting_tokens: usize, max_waiting_tokens: usize,
@ -97,8 +97,18 @@ fn main() -> Result<(), std::io::Error> {
ngrok_password, ngrok_password,
} = args; } = args;
// Validate args
if max_input_length as u32 > max_batch_prefill_tokens {
panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"));
}
if max_total_tokens as u32 > max_batch_total_tokens {
panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"));
}
if validation_workers == 0 { if validation_workers == 0 {
panic!("validation_workers must be > 0"); panic!("`validation_workers` must be > 0");
} }
// CORS allowed origins // CORS allowed origins