feat(router): add tests to validation (#237)
This commit is contained in:
parent
77758f603b
commit
c4fb09f2ae
|
@ -67,6 +67,9 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
pip install pytest
|
pip install pytest
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
|
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
|
||||||
|
- name: Run Clippy
|
||||||
|
run: |
|
||||||
|
cargo clippy
|
||||||
- name: Run Rust tests
|
- name: Run Rust tests
|
||||||
run: |
|
run: |
|
||||||
cargo test
|
cargo test
|
||||||
|
|
|
@ -276,3 +276,19 @@ pub(crate) struct ErrorResponse {
|
||||||
pub error: String,
|
pub error: String,
|
||||||
pub error_type: String,
|
pub error_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests{
|
||||||
|
use std::io::Write;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
pub(crate) async fn get_tokenizer() -> Tokenizer{
|
||||||
|
if !std::path::Path::new("tokenizer.json").exists(){
|
||||||
|
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
|
||||||
|
let mut file = std::fs::File::create("tokenizer.json").unwrap();
|
||||||
|
file.write_all(&content).unwrap();
|
||||||
|
}
|
||||||
|
Tokenizer::from_file("tokenizer.json").unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -141,6 +141,7 @@ impl State {
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
|
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
|
||||||
|
|
||||||
if self.entries.is_empty() {
|
if self.entries.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
@ -430,7 +431,17 @@ mod tests {
|
||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
queue.append(entry3);
|
queue.append(entry3);
|
||||||
|
|
||||||
|
// Not enough requests pending
|
||||||
assert!(queue.next_batch(Some(2), 2).await.is_none());
|
assert!(queue.next_batch(Some(2), 2).await.is_none());
|
||||||
|
// Not enough token budget
|
||||||
|
assert!(queue.next_batch(Some(1), 0).await.is_none());
|
||||||
|
// Ok
|
||||||
|
let (entries2, batch2, _) = queue.next_batch(Some(1), 2).await.unwrap();
|
||||||
|
assert_eq!(entries2.len(), 1);
|
||||||
|
assert!(entries2.contains_key(&2));
|
||||||
|
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||||
|
assert_eq!(batch2.id, 1);
|
||||||
|
assert_eq!(batch2.size, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
@ -741,3 +741,4 @@ impl From<InferError> for Event {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -382,7 +382,8 @@ pub enum ValidationError {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests{
|
mod tests{
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::io::Write;
|
use crate::default_parameters;
|
||||||
|
use crate::tests::get_tokenizer;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validation_max_new_tokens(){
|
async fn test_validation_max_new_tokens(){
|
||||||
|
@ -401,15 +402,6 @@ mod tests{
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_tokenizer() -> Tokenizer{
|
|
||||||
if !std::path::Path::new("tokenizer.json").exists(){
|
|
||||||
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
|
|
||||||
let mut file = std::fs::File::create("tokenizer.json").unwrap();
|
|
||||||
file.write_all(&content).unwrap();
|
|
||||||
}
|
|
||||||
Tokenizer::from_file("tokenizer.json").unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validation_input_length(){
|
async fn test_validation_input_length(){
|
||||||
let tokenizer = Some(get_tokenizer().await);
|
let tokenizer = Some(get_tokenizer().await);
|
||||||
|
@ -426,4 +418,73 @@ mod tests{
|
||||||
_ => panic!("Unexpected not max new tokens")
|
_ => panic!("Unexpected not max new tokens")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_validation_best_of_sampling(){
|
||||||
|
let tokenizer = Some(get_tokenizer().await);
|
||||||
|
let max_best_of = 2;
|
||||||
|
let max_stop_sequence = 3;
|
||||||
|
let max_input_length = 4;
|
||||||
|
let max_total_tokens = 5;
|
||||||
|
let workers = 1;
|
||||||
|
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
||||||
|
match validation.validate(GenerateRequest{
|
||||||
|
inputs: "Hello".to_string(),
|
||||||
|
parameters: GenerateParameters{
|
||||||
|
best_of: Some(2),
|
||||||
|
do_sample: false,
|
||||||
|
..default_parameters()
|
||||||
|
}
|
||||||
|
}).await{
|
||||||
|
Err(ValidationError::BestOfSampling) => (),
|
||||||
|
_ => panic!("Unexpected not best of sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_validation_top_p(){
|
||||||
|
let tokenizer = Some(get_tokenizer().await);
|
||||||
|
let max_best_of = 2;
|
||||||
|
let max_stop_sequence = 3;
|
||||||
|
let max_input_length = 4;
|
||||||
|
let max_total_tokens = 5;
|
||||||
|
let workers = 1;
|
||||||
|
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
||||||
|
match validation.validate(GenerateRequest{
|
||||||
|
inputs: "Hello".to_string(),
|
||||||
|
parameters: GenerateParameters{
|
||||||
|
top_p: Some(1.0),
|
||||||
|
..default_parameters()
|
||||||
|
}
|
||||||
|
}).await{
|
||||||
|
Err(ValidationError::TopP) => (),
|
||||||
|
_ => panic!("Unexpected top_p")
|
||||||
|
}
|
||||||
|
|
||||||
|
match validation.validate(GenerateRequest{
|
||||||
|
inputs: "Hello".to_string(),
|
||||||
|
parameters: GenerateParameters{
|
||||||
|
top_p: Some(0.99),
|
||||||
|
max_new_tokens: 1,
|
||||||
|
..default_parameters()
|
||||||
|
}
|
||||||
|
}).await{
|
||||||
|
Ok(_) => (),
|
||||||
|
_ => panic!("Unexpected top_p error")
|
||||||
|
}
|
||||||
|
|
||||||
|
let valid_request = validation.validate(GenerateRequest{
|
||||||
|
inputs: "Hello".to_string(),
|
||||||
|
parameters: GenerateParameters{
|
||||||
|
top_p: None,
|
||||||
|
max_new_tokens: 1,
|
||||||
|
..default_parameters()
|
||||||
|
}
|
||||||
|
}).await.unwrap();
|
||||||
|
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
|
||||||
|
assert_eq!(valid_request.parameters.top_p, 1.0);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue