Starting some routing tests. (#233)
parent 323546df1d
commit 45344244cf
.gitignore
@@ -1,2 +1,3 @@
 .idea
 target
+router/tokenizer.json
router/src/validation.rs
@@ -378,3 +378,52 @@ pub enum ValidationError {
     #[error("tokenizer error {0}")]
     Tokenizer(String),
 }
+
+#[cfg(test)]
+mod tests{
+    use super::*;
+    use std::io::Write;
+
+    #[tokio::test]
+    async fn test_validation_max_new_tokens(){
+        let tokenizer = None;
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+
+        let max_new_tokens = 10;
+        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+            Err(ValidationError::MaxNewTokens(1, 10)) => (),
+            _ => panic!("Unexpected not max new tokens")
+        }
+    }
+
+    async fn get_tokenizer() -> Tokenizer{
+        if !std::path::Path::new("tokenizer.json").exists(){
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+            file.write_all(&content).unwrap();
+        }
+        Tokenizer::from_file("tokenizer.json").unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_validation_input_length(){
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_input_length = 4;
+        let max_total_tokens = 5;
+        let workers = 1;
+        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+
+        let max_new_tokens = 10;
+        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+            Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+            _ => panic!("Unexpected not max new tokens")
+        }
+    }
+}
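A note on the expected error values, as a reading of the limits set up above rather than anything this diff states: with no tokenizer the router cannot count input tokens, so the generation budget appears to be taken against the worst-case input length, giving max_total_tokens - max_input_length = 5 - 4 = 1 and hence MaxNewTokens(1, 10); with the GPT-2 tokenizer, "Hello" encodes to a single token, so 1 + 10 exceeds the total budget of 5 and yields MaxTotalTokens(5, 1, 10). A minimal sketch of that arithmetic, under those assumptions:

// Sketch of the token-budget arithmetic the two tests appear to assert.
// This is an assumption about the validation logic, not code from this commit.
fn main() {
    let max_input_length = 4u32;
    let max_total_tokens = 5u32;
    let max_new_tokens = 10u32;

    // No tokenizer: budget generation against the worst-case input length.
    let budget = max_total_tokens - max_input_length; // 5 - 4 = 1
    assert!(max_new_tokens > budget); // expected Err(MaxNewTokens(1, 10))

    // GPT-2 tokenizer: "Hello" is one token, so the request needs 1 + 10 = 11 tokens.
    let input_tokens = 1u32;
    assert!(input_tokens + max_new_tokens > max_total_tokens); // expected Err(MaxTotalTokens(5, 1, 10))
}

Assuming the tests are run from the router crate (e.g. a plain cargo test there), get_tokenizer downloads the GPT-2 tokenizer.json into that directory, which would explain the new router/tokenizer.json entry in .gitignore.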