From 45344244cf194c89c94fdbdfc1bdca724c7df4e5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Apr 2023 14:13:14 +0200 Subject: [PATCH] Starting some routing tests. (#233) --- .gitignore | 3 ++- router/src/validation.rs | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ec376bb8..19604d42 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea -target \ No newline at end of file +target +router/tokenizer.json diff --git a/router/src/validation.rs b/router/src/validation.rs index 7f2f76e6..983c2612 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -378,3 +378,52 @@ pub enum ValidationError { #[error("tokenizer error {0}")] Tokenizer(String), } + +#[cfg(test)] +mod tests{ + use super::*; + use std::io::Write; + + #[tokio::test] + async fn test_validation_max_new_tokens(){ + let tokenizer = None; + let max_best_of = 2; + let max_stop_sequence = 3; + let max_input_length = 4; + let max_total_tokens = 5; + let workers = 1; + let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); + + let max_new_tokens = 10; + match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{ + Err(ValidationError::MaxNewTokens(1, 10)) => (), + _ => panic!("Unexpected not max new tokens") + } + } + + async fn get_tokenizer() -> Tokenizer{ + if !std::path::Path::new("tokenizer.json").exists(){ + let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); + let mut file = std::fs::File::create("tokenizer.json").unwrap(); + file.write_all(&content).unwrap(); + } + Tokenizer::from_file("tokenizer.json").unwrap() + } + + #[tokio::test] + async fn test_validation_input_length(){ + let tokenizer = Some(get_tokenizer().await); + let max_best_of = 2; + let max_stop_sequence = 3; + let max_input_length = 4; + let max_total_tokens = 5; + let workers = 1; + let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); + + let max_new_tokens = 10; + match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{ + Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (), + _ => panic!("Unexpected not max new tokens") + } + } +}