See #1049

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi <yi.a.wang@intel.com>
OlivierDehaene 2023-10-20 10:28:45 +02:00 committed by GitHub
parent 72b8f88be8
commit 5e28f44a83
4 changed files with 12 additions and 7 deletions

View File

@@ -103,17 +103,19 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
+        let mut truncate = 0;
         // Create requests
         while n_tokens < max_prefill_tokens {
+            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
+                truncate: truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -126,9 +128,9 @@ impl Client {
                     watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 2,
+                    max_new_tokens: max_total_tokens - truncate,
                     stop_sequences: vec![],
-                    ignore_eos_token: false,
+                    ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
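
Net effect of this hunk: warmup requests still truncate their inputs to fit the prefill budget, but each one now decodes up to max_total_tokens - truncate new tokens with ignore_eos_token: true, so warmup exercises worst-case decode memory instead of stopping after 2 tokens. A minimal, self-contained sketch of that arithmetic (the function name and example numbers are illustrative, not part of the commit):

use std::cmp::min;

// Sketch only: the per-request values the warmup loop now computes.
// Names mirror the diff above; this is not the actual Client code.
fn warmup_request_shape(
    max_input_length: u32,
    max_prefill_tokens: u32,
    max_total_tokens: u32,
    n_tokens: u32,
) -> (u32, u32) {
    // Inputs are truncated so the accumulated batch never exceeds the
    // prefill budget.
    let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
    // Decode is pushed to the worst case: fill the whole max_total_tokens
    // window (ignore_eos_token: true keeps generation from ending early).
    let max_new_tokens = max_total_tokens - truncate;
    (truncate, max_new_tokens)
}

fn main() {
    // e.g. 1024-token inputs, a 4096-token prefill budget, 2048 total tokens:
    assert_eq!(warmup_request_shape(1024, 4096, 2048, 0), (1024, 1024));
    // Near the end of the budget the input shrinks and the decode grows:
    assert_eq!(warmup_request_shape(1024, 4096, 2048, 3584), (512, 1536));
}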

View File

@@ -95,11 +95,14 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
+        max_total_tokens: u32,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
+            .map(|client| {
+                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
+            })
             .collect();
         // Take the minimum value
         let results = join_all(futures)
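
For context, warmup returns Result<Option<u32>>: the sharded client fans the call out to every shard concurrently and keeps the smallest reported capacity, since a batch must fit on the most constrained shard. A self-contained sketch of that join-then-min pattern (crate versions and the shard_warmup stand-in are assumptions; the real code maps over gRPC clients):

// Assumed deps: futures = "0.3", tokio = { version = "1", features = ["full"] }
use futures::future::join_all;

// Stand-in for one shard's warmup: the max batch total tokens it supports,
// or None when the backend reports no limit.
async fn shard_warmup(capacity: Option<u32>) -> Result<Option<u32>, String> {
    Ok(capacity)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // Fan the same warmup out to every shard concurrently...
    let futures = vec![
        Box::pin(shard_warmup(Some(16_000))),
        Box::pin(shard_warmup(Some(12_000))),
    ];
    let results: Result<Vec<Option<u32>>, String> =
        join_all(futures).await.into_iter().collect();
    // ...and take the minimum, since a batch must fit on the tightest shard.
    assert_eq!(results?.into_iter().flatten().min(), Some(12_000));
    Ok(())
}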

View File

@@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens)
+        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
         .await
         .map_err(RouterError::Warmup)?
     {

View File

@@ -122,7 +122,7 @@ impl Validation {
             if let Some(truncate) = truncate {
                 self.max_total_tokens.saturating_sub(truncate) as u32
             } else {
-                return Err(ValidationError::UnsetMaxNewTokens)
+                return Err(ValidationError::UnsetMaxNewTokens);
             }
         };
         let input_length = truncate.unwrap_or(self.max_input_length);
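
The change in this hunk is only a trailing semicolon, but the surrounding logic is what the new warmup mirrors: when a request leaves max_new_tokens unset, validation derives it as max_total_tokens minus the truncated input length. A self-contained sketch of that defaulting rule (the field types and error value are simplified stand-ins for the real Validation code):

// Sketch only: mirrors the if-let above with plain types.
fn default_max_new_tokens(
    max_total_tokens: usize,
    truncate: Option<usize>,
) -> Result<u32, &'static str> {
    if let Some(truncate) = truncate {
        // saturating_sub clamps to 0 instead of underflowing when the
        // truncated prompt already fills the max_total_tokens window.
        Ok(max_total_tokens.saturating_sub(truncate) as u32)
    } else {
        // Stands in for ValidationError::UnsetMaxNewTokens.
        Err("max_new_tokens is not set and cannot be inferred")
    }
}

fn main() {
    assert_eq!(default_max_new_tokens(2048, Some(1800)), Ok(248));
    assert_eq!(default_max_new_tokens(2048, Some(4096)), Ok(0));
    assert!(default_max_new_tokens(2048, None).is_err());
}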