From 81fa53f37b416e4946d70db5a15aa97159c75b8d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 8 Feb 2024 15:25:34 +0000 Subject: [PATCH] Fix tests. --- router/src/validation.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/router/src/validation.rs b/router/src/validation.rs index 166f0b88..ad5e0d67 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -107,7 +107,7 @@ impl Validation { ) -> Result<(String, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { - if self.batch_dimension { + if !self.batch_dimension { let input_length = encoding.len(); // Get total tokens @@ -465,6 +465,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let batch_dimension = false; let validation = Validation::new( workers, tokenizer, @@ -473,6 +474,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension ); let max_new_tokens = 10; @@ -494,6 +496,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let batch_dimension = false; let validation = Validation::new( workers, tokenizer, @@ -502,6 +505,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension, ); let max_new_tokens = 10; @@ -523,6 +527,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let batch_dimension = false; let validation = Validation::new( workers, tokenizer, @@ -531,6 +536,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension ); match validation .validate(GenerateRequest { @@ -557,6 +563,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 106; let workers = 1; + let batch_dimension = false; let validation = Validation::new( workers, tokenizer, @@ -565,6 +572,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension ); match validation .validate(GenerateRequest { @@ -620,6 +628,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 106; let workers = 1; + let batch_dimension = false; let validation = Validation::new( workers, tokenizer, @@ -628,6 +637,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + batch_dimension, ); match validation .validate(GenerateRequest {