From e1dc1681880d5f6dd5dbcd22b0c461888b0e69b0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 5 Feb 2024 14:29:32 +0100 Subject: [PATCH] Adding batch_dimension_flag (to be used for Neuron other forced padding targets). --- Cargo.lock | 38 ++++++++++++++- launcher/src/main.rs | 10 ++++ router/src/main.rs | 4 ++ router/src/server.rs | 2 + router/src/validation.rs | 102 ++++++++++++++++++++------------------- 5 files changed, 105 insertions(+), 51 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7fdf301a..3318f3b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2787,7 +2787,7 @@ dependencies = [ "tabled", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.14.1", "tokio", "tracing", "tracing-subscriber", @@ -2850,7 +2850,7 @@ dependencies = [ "serde_json", "text-generation-client", "thiserror", - "tokenizers", + "tokenizers 0.15.1", "tokio", "tokio-stream", "tower-http", @@ -2972,6 +2972,40 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812" +dependencies = [ + "aho-corasick", + "clap", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.11.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.7.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.35.1" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 054e546c..e6799cb3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -368,6 +368,12 @@ struct Args { #[clap(long, env)] ngrok_edge: Option, + /// Specific flag for hardware targets that do not support unpadded inference + /// For those we do not send the tokenizer to the router so that all the scheduling + /// assumes those pad tokens exist (and potentially even more). + #[clap(long, env)] + batch_dimension: bool, + /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may /// include a `chat_template`. If not provided, the default config will be used from the model hub. #[clap(long, env)] @@ -1034,6 +1040,10 @@ fn spawn_webserver( args.model_id, ]; + if args.batch_dimension{ + router_args.push("--batch-dimension".to_string()); + } + // Tokenizer config path if let Some(ref tokenizer_config_path) = args.tokenizer_config_path { router_args.push("--tokenizer-config-path".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index 2a080468..87b5de5f 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -37,6 +37,8 @@ struct Args { max_input_length: usize, #[clap(default_value = "2048", long, env)] max_total_tokens: usize, + #[clap(long, env)] + batch_dimension: bool, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] @@ -87,6 +89,7 @@ async fn main() -> Result<(), RouterError> { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, @@ -340,6 +343,7 @@ async fn main() -> Result<(), RouterError> { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension, waiting_served_ratio, max_batch_prefill_tokens, max_supported_batch_total_tokens, diff --git a/router/src/server.rs b/router/src/server.rs index b4d26158..42ddca68 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -758,6 +758,7 @@ pub async fn run( max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + batch_dimension: bool, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, max_batch_total_tokens: u32, @@ -833,6 +834,7 @@ pub async fn run( max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension ); let generation_health = Arc::new(AtomicBool::new(false)); let health_ext = Health::new(client.clone(), generation_health.clone()); diff --git a/router/src/validation.rs b/router/src/validation.rs index 750b98e5..3d83b10c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -19,6 +19,7 @@ pub struct Validation { max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + batched_dimension: bool, /// Channel to communicate with the background tokenization task sender: Option>, } @@ -32,6 +33,7 @@ impl Validation { max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + batched_dimension: bool, ) -> Self { // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { @@ -66,6 +68,7 @@ impl Validation { max_top_n_tokens, max_input_length, max_total_tokens, + batched_dimension } } @@ -103,61 +106,62 @@ impl Validation { ) -> Result<(String, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { - // Create response channel - let input_length = encoding.len(); + if self.batched_dimension{ + let input_length = encoding.len(); - // Get total tokens - let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { - max_new_tokens - } else { - self.max_total_tokens.saturating_sub(input_length) as u32 - }; - let total_tokens = input_length + max_new_tokens as usize; + // Get total tokens + let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { + max_new_tokens + } else { + self.max_total_tokens.saturating_sub(input_length) as u32 + }; + let total_tokens = input_length + max_new_tokens as usize; - // Validate MaxTotalTokens - if total_tokens > self.max_total_tokens { - return Err(ValidationError::MaxTotalTokens( - self.max_total_tokens, - input_length, - max_new_tokens, - )); + // Validate MaxTotalTokens + if total_tokens > self.max_total_tokens { + return Err(ValidationError::MaxTotalTokens( + self.max_total_tokens, + input_length, + max_new_tokens, + )); + } + + // Validate InputLength + if input_length > self.max_input_length { + return Err(ValidationError::InputLength( + self.max_input_length, + input_length, + )); + } + + // + metrics::histogram!("tgi_request_input_length", input_length as f64); + return Ok((inputs, input_length, max_new_tokens)); } - - // Validate InputLength - if input_length > self.max_input_length { - return Err(ValidationError::InputLength( - self.max_input_length, - input_length, - )); - } - - metrics::histogram!("tgi_request_input_length", input_length as f64); - Ok((inputs, input_length, max_new_tokens)) } - // Return inputs without validation - else { - // In this case, we don't know the real length in tokens of the inputs - // However, the inputs will be truncated by the python servers - // We make sure that truncate + max_new_tokens <= self.max_total_tokens - let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { - max_new_tokens - } else if let Some(truncate) = truncate { - self.max_total_tokens.saturating_sub(truncate) as u32 - } else { - return Err(ValidationError::UnsetMaxNewTokens); - }; - let input_length = truncate.unwrap_or(self.max_input_length); + // Either we don't have a tokenizer or batched_dimension purposefully + // will ignore the actual length in order to schedule the job correctly. + // In this case, we don't know the real length in tokens of the inputs + // However, the inputs will be truncated by the python servers + // We make sure that truncate + max_new_tokens <= self.max_total_tokens + let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { + max_new_tokens + } else if let Some(truncate) = truncate { + self.max_total_tokens.saturating_sub(truncate) as u32 + } else { + return Err(ValidationError::UnsetMaxNewTokens); + }; + let input_length = truncate.unwrap_or(self.max_input_length); - // Validate MaxNewTokens - if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { - return Err(ValidationError::MaxNewTokens( - self.max_total_tokens - self.max_input_length, - max_new_tokens, - )); - } - - Ok((inputs, input_length, max_new_tokens)) + // Validate MaxNewTokens + if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { + return Err(ValidationError::MaxNewTokens( + self.max_total_tokens - self.max_input_length, + max_new_tokens, + )); } + + Ok((inputs, input_length, max_new_tokens)) } /// Validate a payload and get the number of tokens in the input