From 3b0c979efcccd8ca51f59f1f982bfbbc842d06c9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 30 Jun 2023 20:07:49 +0200 Subject: [PATCH] feat(router): arg validation (#519) --- launcher/src/main.rs | 4 ++-- router/src/main.rs | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8497f807..9d6cd4dd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -101,7 +101,7 @@ struct Args { /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. /// Please note that some models have a finite range of sequence they can handle. - #[clap(default_value = "1000", long, env)] + #[clap(default_value = "1024", long, env)] max_input_length: usize, /// This is the most important value to set as it defines the "memory budget" @@ -112,7 +112,7 @@ struct Args { /// `1511` max_new_tokens. /// The larger this value, the larger amount each request will be in your RAM /// and the less effective batching can be. - #[clap(default_value = "1512", long, env)] + #[clap(default_value = "2048", long, env)] max_total_tokens: usize, /// This represents the ratio of waiting queries vs running queries where diff --git a/router/src/main.rs b/router/src/main.rs index 47d48e3f..f782be09 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -28,15 +28,15 @@ struct Args { max_best_of: usize, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, - #[clap(default_value = "1000", long, env)] + #[clap(default_value = "1024", long, env)] max_input_length: usize, - #[clap(default_value = "1512", long, env)] + #[clap(default_value = "2048", long, env)] max_total_tokens: usize, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] max_batch_prefill_tokens: u32, - #[clap(default_value = "32000", long, env)] + #[clap(default_value = "16000", long, env)] max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, @@ -97,8 +97,18 @@ fn main() -> Result<(), std::io::Error> { ngrok_password, } = args; + // Validate args + if max_input_length as u32 > max_batch_prefill_tokens { + panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")); + } + if max_total_tokens as u32 > max_batch_total_tokens { + panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")); + } if validation_workers == 0 { - panic!("validation_workers must be > 0"); + panic!("`validation_workers` must be > 0"); } // CORS allowed origins