add logic to queue
This commit is contained in:
parent
4f460e5bfe
commit
a963495315
|
@ -162,7 +162,7 @@ impl State {
|
||||||
|
|
||||||
let mut max_input_length = 0;
|
let mut max_input_length = 0;
|
||||||
let mut prefill_tokens: u32 = 0;
|
let mut prefill_tokens: u32 = 0;
|
||||||
let mut decode_tokens: u32 = 0;
|
let mut max_decode_steps: u32 = u32::MAX;
|
||||||
|
|
||||||
// Pop entries starting from the front of the queue
|
// Pop entries starting from the front of the queue
|
||||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||||
|
@ -182,7 +182,10 @@ impl State {
|
||||||
prefill_tokens += entry.request.input_length;
|
prefill_tokens += entry.request.input_length;
|
||||||
}
|
}
|
||||||
|
|
||||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
max_decode_steps =
|
||||||
|
max_decode_steps.min(entry.request.stopping_parameters.max_new_tokens);
|
||||||
|
|
||||||
|
let decode_tokens = max_decode_steps * (batch_requests.len() + 1) as u32;
|
||||||
|
|
||||||
if (prefill_tokens + decode_tokens) > token_budget {
|
if (prefill_tokens + decode_tokens) > token_budget {
|
||||||
// Entry is over budget
|
// Entry is over budget
|
||||||
|
@ -236,6 +239,8 @@ impl State {
|
||||||
let size = batch_requests.len() as u32;
|
let size = batch_requests.len() as u32;
|
||||||
next_batch_span.record("batch_size", size);
|
next_batch_span.record("batch_size", size);
|
||||||
|
|
||||||
|
let decode_tokens = size * max_decode_steps;
|
||||||
|
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: self.next_batch_id,
|
id: self.next_batch_id,
|
||||||
requests: batch_requests,
|
requests: batch_requests,
|
||||||
|
|
|
@ -380,30 +380,45 @@ pub enum ValidationError {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests{
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validation_max_new_tokens(){
|
async fn test_validation_max_new_tokens() {
|
||||||
let tokenizer = None;
|
let tokenizer = None;
|
||||||
let max_best_of = 2;
|
let max_best_of = 2;
|
||||||
let max_stop_sequence = 3;
|
let max_stop_sequence = 3;
|
||||||
let max_input_length = 4;
|
let max_input_length = 4;
|
||||||
let max_total_tokens = 5;
|
let max_total_tokens = 5;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
let validation = Validation::new(
|
||||||
|
workers,
|
||||||
|
tokenizer,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequence,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
);
|
||||||
|
|
||||||
let max_new_tokens = 10;
|
let max_new_tokens = 10;
|
||||||
match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
|
match validation
|
||||||
|
.validate_input("Hello".to_string(), None, max_new_tokens)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
||||||
_ => panic!("Unexpected not max new tokens")
|
_ => panic!("Unexpected not max new tokens"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_tokenizer() -> Tokenizer{
|
async fn get_tokenizer() -> Tokenizer {
|
||||||
if !std::path::Path::new("tokenizer.json").exists(){
|
if !std::path::Path::new("tokenizer.json").exists() {
|
||||||
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
|
let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let mut file = std::fs::File::create("tokenizer.json").unwrap();
|
let mut file = std::fs::File::create("tokenizer.json").unwrap();
|
||||||
file.write_all(&content).unwrap();
|
file.write_all(&content).unwrap();
|
||||||
}
|
}
|
||||||
|
@ -411,19 +426,29 @@ mod tests{
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validation_input_length(){
|
async fn test_validation_input_length() {
|
||||||
let tokenizer = Some(get_tokenizer().await);
|
let tokenizer = Some(get_tokenizer().await);
|
||||||
let max_best_of = 2;
|
let max_best_of = 2;
|
||||||
let max_stop_sequence = 3;
|
let max_stop_sequence = 3;
|
||||||
let max_input_length = 4;
|
let max_input_length = 4;
|
||||||
let max_total_tokens = 5;
|
let max_total_tokens = 5;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
|
let validation = Validation::new(
|
||||||
|
workers,
|
||||||
|
tokenizer,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequence,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
);
|
||||||
|
|
||||||
let max_new_tokens = 10;
|
let max_new_tokens = 10;
|
||||||
match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
|
match validation
|
||||||
|
.validate_input("Hello".to_string(), None, max_new_tokens)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
|
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
|
||||||
_ => panic!("Unexpected not max new tokens")
|
_ => panic!("Unexpected not max new tokens"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue