diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index e82e8b20..a2c2b7fb 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -67,6 +67,9 @@ jobs:
        run: |
          pip install pytest
          HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
+      - name: Run Clippy
+        run: |
+          cargo clippy
      - name: Run Rust tests
        run: |
          cargo test
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7a1707d9..85b13cfa 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -276,3 +276,19 @@ pub(crate) struct ErrorResponse {
     pub error: String,
     pub error_type: String,
 }
+
+#[cfg(test)]
+mod tests{
+    use std::io::Write;
+    use tokenizers::Tokenizer;
+
+    pub(crate) async fn get_tokenizer() -> Tokenizer{
+        if !std::path::Path::new("tokenizer.json").exists(){
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+            file.write_all(&content).unwrap();
+        }
+        Tokenizer::from_file("tokenizer.json").unwrap()
+    }
+}
+
diff --git a/router/src/queue.rs b/router/src/queue.rs
index d970ebf1..d3f118d8 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -141,6 +141,7 @@ impl State {
 
     // Get the next batch
     fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
+
         if self.entries.is_empty() {
             return None;
         }
@@ -430,7 +431,17 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         queue.append(entry3);
 
+        // Not enough requests pending
         assert!(queue.next_batch(Some(2), 2).await.is_none());
+        // Not enough token budget
+        assert!(queue.next_batch(Some(1), 0).await.is_none());
+        // Ok
+        let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap();
+        assert_eq!(entries2.len(), 1);
+        assert!(entries2.contains_key(&2));
+        assert!(entries2.get(&2).unwrap().batch_time.is_some());
+        assert_eq!(batch2.id, 1);
+        assert_eq!(batch2.size, 1);
     }
 
     #[tokio::test]
diff --git a/router/src/server.rs b/router/src/server.rs
index 9540ba18..09b5c3ba 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -741,3 +741,4 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
+
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 983c2612..ff2fe89d 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -382,7 +382,8 @@ pub enum ValidationError {
 #[cfg(test)]
 mod tests{
     use super::*;
-    use std::io::Write;
+    use crate::default_parameters;
+    use crate::tests::get_tokenizer;
 
     #[tokio::test]
     async fn test_validation_max_new_tokens(){
@@ -401,15 +402,6 @@ mod tests{
         }
     }
 
-    async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
-            file.write_all(&content).unwrap();
-        }
-        Tokenizer::from_file("tokenizer.json").unwrap()
-    }
-
     #[tokio::test]
     async fn test_validation_input_length(){
         let tokenizer = Some(get_tokenizer().await);
@@ -426,4 +418,73 @@ mod tests{
             _ => panic!("Unexpected not max new tokens")
         }
     }
+
+    #[tokio::test]
+    async fn test_validation_best_of_sampling(){
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        match validation.validate(GenerateRequest{
+            inputs: "Hello".to_string(),
+            parameters: GenerateParameters{
+                best_of: Some(2),
+                do_sample: false,
+                ..default_parameters()
+            }
+        }).await{
+            Err(ValidationError::BestOfSampling) => (),
+            _ => panic!("Unexpected not best of sampling")
+        }
+
+    }
+
+    #[tokio::test]
+    async fn test_validation_top_p(){
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        match validation.validate(GenerateRequest{
+            inputs: "Hello".to_string(),
+            parameters: GenerateParameters{
+                top_p: Some(1.0),
+                ..default_parameters()
+            }
+        }).await{
+            Err(ValidationError::TopP) => (),
+            _ => panic!("Unexpected top_p")
+        }
+
+        match validation.validate(GenerateRequest{
+            inputs: "Hello".to_string(),
+            parameters: GenerateParameters{
+                top_p: Some(0.99),
+                max_new_tokens: 1,
+                ..default_parameters()
+            }
+        }).await{
+            Ok(_) => (),
+            _ => panic!("Unexpected top_p error")
+        }
+
+        let valid_request = validation.validate(GenerateRequest{
+            inputs: "Hello".to_string(),
+            parameters: GenerateParameters{
+                top_p: None,
+                max_new_tokens: 1,
+                ..default_parameters()
+            }
+        }).await.unwrap();
+        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
+        assert_eq!(valid_request.parameters.top_p, 1.0);
+
+
+    }
 }