add logic to queue

This commit is contained in:
OlivierDehaene 2023-04-26 13:40:20 +02:00
parent 4f460e5bfe
commit a963495315
2 changed files with 45 additions and 15 deletions

View File

@ -162,7 +162,7 @@ impl State {
let mut max_input_length = 0; let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0; let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0; let mut max_decode_steps: u32 = u32::MAX;
// Pop entries starting from the front of the queue // Pop entries starting from the front of the queue
while let Some((id, mut entry)) = self.entries.pop_front() { while let Some((id, mut entry)) = self.entries.pop_front() {
@ -182,7 +182,10 @@ impl State {
prefill_tokens += entry.request.input_length; prefill_tokens += entry.request.input_length;
} }
decode_tokens += entry.request.stopping_parameters.max_new_tokens; max_decode_steps =
max_decode_steps.min(entry.request.stopping_parameters.max_new_tokens);
let decode_tokens = max_decode_steps * (batch_requests.len() + 1) as u32;
if (prefill_tokens + decode_tokens) > token_budget { if (prefill_tokens + decode_tokens) > token_budget {
// Entry is over budget // Entry is over budget
@ -236,6 +239,8 @@ impl State {
let size = batch_requests.len() as u32; let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size); next_batch_span.record("batch_size", size);
let decode_tokens = size * max_decode_steps;
let batch = Batch { let batch = Batch {
id: self.next_batch_id, id: self.next_batch_id,
requests: batch_requests, requests: batch_requests,

View File

@ -380,30 +380,45 @@ pub enum ValidationError {
} }
#[cfg(test)] #[cfg(test)]
mod tests{ mod tests {
use super::*; use super::*;
use std::io::Write; use std::io::Write;
#[tokio::test] #[tokio::test]
async fn test_validation_max_new_tokens(){ async fn test_validation_max_new_tokens() {
let tokenizer = None; let tokenizer = None;
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_input_length = 4;
let max_total_tokens = 5; let max_total_tokens = 5;
let workers = 1; let workers = 1;
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_input_length,
max_total_tokens,
);
let max_new_tokens = 10; let max_new_tokens = 10;
match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{ match validation
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
Err(ValidationError::MaxNewTokens(1, 10)) => (), Err(ValidationError::MaxNewTokens(1, 10)) => (),
_ => panic!("Unexpected not max new tokens") _ => panic!("Unexpected not max new tokens"),
} }
} }
async fn get_tokenizer() -> Tokenizer{ async fn get_tokenizer() -> Tokenizer {
if !std::path::Path::new("tokenizer.json").exists(){ if !std::path::Path::new("tokenizer.json").exists() {
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap(); let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
.await
.unwrap()
.bytes()
.await
.unwrap();
let mut file = std::fs::File::create("tokenizer.json").unwrap(); let mut file = std::fs::File::create("tokenizer.json").unwrap();
file.write_all(&content).unwrap(); file.write_all(&content).unwrap();
} }
@ -411,19 +426,29 @@ mod tests{
} }
#[tokio::test] #[tokio::test]
async fn test_validation_input_length(){ async fn test_validation_input_length() {
let tokenizer = Some(get_tokenizer().await); let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_input_length = 4;
let max_total_tokens = 5; let max_total_tokens = 5;
let workers = 1; let workers = 1;
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens); let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_input_length,
max_total_tokens,
);
let max_new_tokens = 10; let max_new_tokens = 10;
match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{ match validation
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens") _ => panic!("Unexpected not max new tokens"),
} }
} }
} }