diff --git a/router/src/queue.rs b/router/src/queue.rs
index d970ebf1..f9de592b 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -162,7 +162,7 @@ impl State {
 
         let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
-        let mut decode_tokens: u32 = 0;
+        let mut max_decode_steps: u32 = u32::MAX;
 
         // Pop entries starting from the front of the queue
         while let Some((id, mut entry)) = self.entries.pop_front() {
@@ -182,7 +182,10 @@
                 prefill_tokens += entry.request.input_length;
             }
 
-            decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+            max_decode_steps =
+                max_decode_steps.min(entry.request.stopping_parameters.max_new_tokens);
+
+            let decode_tokens = max_decode_steps * (batch_requests.len() + 1) as u32;
 
             if (prefill_tokens + decode_tokens) > token_budget {
                 // Entry is over budget
@@ -236,6 +239,8 @@
         let size = batch_requests.len() as u32;
         next_batch_span.record("batch_size", size);
 
+        let decode_tokens = size * max_decode_steps;
+
         let batch = Batch {
             id: self.next_batch_id,
             requests: batch_requests,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 983c2612..37147a03 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -380,50 +380,75 @@ pub enum ValidationError {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use std::io::Write;
 
     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
-    async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
 
     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 }