See #1049 --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Wang, Yi <yi.a.wang@intel.com>
This commit is contained in:
parent
72b8f88be8
commit
5e28f44a83
|
@ -103,17 +103,19 @@ impl Client {
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<Option<u32>> {
|
||||||
let mut n_tokens = 0;
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
|
let mut truncate = 0;
|
||||||
// Create requests
|
// Create requests
|
||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
|
truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
// We truncate the input on the server side to be sure that it has the correct size
|
// We truncate the input on the server side to be sure that it has the correct size
|
||||||
inputs: "_test ".to_string().repeat(max_input_length as usize),
|
inputs: "_test ".to_string().repeat(max_input_length as usize),
|
||||||
truncate: min(max_input_length, max_prefill_tokens - n_tokens),
|
truncate: truncate,
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
|
@ -126,9 +128,9 @@ impl Client {
|
||||||
watermark: true,
|
watermark: true,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: 2,
|
max_new_tokens: max_total_tokens - truncate,
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
ignore_eos_token: false,
|
ignore_eos_token: true,
|
||||||
}),
|
}),
|
||||||
prefill_logprobs: true,
|
prefill_logprobs: true,
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
|
|
|
@ -95,11 +95,14 @@ impl ShardedClient {
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<Option<u32>> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
|
.map(|client| {
|
||||||
|
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
// Take the minimum value
|
// Take the minimum value
|
||||||
let results = join_all(futures)
|
let results = join_all(futures)
|
||||||
|
|
|
@ -212,7 +212,7 @@ fn main() -> Result<(), RouterError> {
|
||||||
// Warmup model
|
// Warmup model
|
||||||
tracing::info!("Warming up model");
|
tracing::info!("Warming up model");
|
||||||
let max_supported_batch_total_tokens = match sharded_client
|
let max_supported_batch_total_tokens = match sharded_client
|
||||||
.warmup(max_input_length as u32, max_batch_prefill_tokens)
|
.warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
|
||||||
.await
|
.await
|
||||||
.map_err(RouterError::Warmup)?
|
.map_err(RouterError::Warmup)?
|
||||||
{
|
{
|
||||||
|
|
|
@ -122,7 +122,7 @@ impl Validation {
|
||||||
if let Some(truncate) = truncate {
|
if let Some(truncate) = truncate {
|
||||||
self.max_total_tokens.saturating_sub(truncate) as u32
|
self.max_total_tokens.saturating_sub(truncate) as u32
|
||||||
} else {
|
} else {
|
||||||
return Err(ValidationError::UnsetMaxNewTokens)
|
return Err(ValidationError::UnsetMaxNewTokens);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||||
|
|
Loading…
Reference in New Issue