feat: auto max_new_tokens (#2803)
* feat: auto max_new_tokens * update default * Fixing the tests. --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
6685e8fcda
commit
8c3669b287
|
@ -436,6 +436,7 @@ mod tests {
|
|||
stopping_parameters: ValidStoppingParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 1,
|
||||
max_total_new_tokens: 1024,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
|
|
|
@ -573,6 +573,7 @@ mod tests {
|
|||
stopping_parameters: ValidStoppingParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 1,
|
||||
max_total_new_tokens: 1024,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
|
|
|
@ -1013,6 +1013,7 @@
|
|||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "The maximum number of tokens that can be generated in the chat completion.",
|
||||
"default": "1024",
|
||||
"example": "32",
|
||||
"nullable": true,
|
||||
"minimum": 0
|
||||
|
@ -1329,7 +1330,8 @@
|
|||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "The maximum number of tokens that can be generated in the chat completion.",
|
||||
"default": "32",
|
||||
"default": "1024",
|
||||
"example": "32",
|
||||
"nullable": true,
|
||||
"minimum": 0
|
||||
},
|
||||
|
@ -1591,7 +1593,7 @@
|
|||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "Maximum number of tokens to generate.",
|
||||
"default": "100",
|
||||
"default": "1024",
|
||||
"example": "20",
|
||||
"nullable": true,
|
||||
"minimum": 0
|
||||
|
|
|
@ -111,21 +111,79 @@ impl Infer {
|
|||
})?;
|
||||
|
||||
// Validate request
|
||||
let mut local_request = request.clone();
|
||||
let valid_request = self.validation.validate(request).await.map_err(|err| {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
})?;
|
||||
|
||||
let seed = valid_request.parameters.seed;
|
||||
local_request.parameters.seed = Some(seed);
|
||||
let input_length = valid_request.input_length;
|
||||
let max_total_new_tokens = valid_request.stopping_parameters.max_total_new_tokens;
|
||||
let mut generation_stream = self.backend.schedule(valid_request)?;
|
||||
|
||||
// Wrap generation stream to update the backend health if the stream contains an error
|
||||
let final_stream = stream! {
|
||||
let mut total_generated_tokens = 0;
|
||||
let mut first_start = None;
|
||||
let mut first_queued = None;
|
||||
let mut all_generated_text: Option<GeneratedText> = None;
|
||||
|
||||
while let Some(response) = generation_stream.next().await {
|
||||
yield response.inspect_err(|_err| {
|
||||
let response = response.inspect_err(|_err| {
|
||||
self.backend_health.store(false, Ordering::SeqCst);
|
||||
})
|
||||
})?;
|
||||
|
||||
match response {
|
||||
InferStreamResponse::Prefill(_) => yield Ok(response),
|
||||
InferStreamResponse::Intermediate { .. } => {
|
||||
total_generated_tokens += 1;
|
||||
yield Ok(response);
|
||||
}
|
||||
InferStreamResponse::End { token, top_tokens,generated_text, start, queued } => {
|
||||
total_generated_tokens += 1;
|
||||
first_start = first_start.or(Some(start));
|
||||
first_queued = first_queued.or(Some(queued));
|
||||
if let Some(v) = all_generated_text.as_mut() {
|
||||
v.text.push_str(&generated_text.text);
|
||||
v.generated_tokens = total_generated_tokens;
|
||||
v.finish_reason = generated_text.finish_reason.clone();
|
||||
};
|
||||
|
||||
if matches!(generated_text.finish_reason, FinishReason::Length) && total_generated_tokens < max_total_new_tokens {
|
||||
local_request.inputs.push_str(&generated_text.text);
|
||||
all_generated_text = all_generated_text.or(Some(generated_text));
|
||||
|
||||
let valid_request = match self.validation.validate(local_request.clone()).await {
|
||||
Ok(valid_request) => valid_request,
|
||||
Err(err) => {
|
||||
tracing::debug!("Failed to continue request: {err}");
|
||||
yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
generation_stream = match self.backend.schedule(valid_request) {
|
||||
Ok(stream) => {
|
||||
tracing::debug!("Continue request");
|
||||
yield Ok(InferStreamResponse::Intermediate { token, top_tokens } );
|
||||
stream
|
||||
},
|
||||
Err(err) => {
|
||||
tracing::debug!("Failed to continue request: {err}");
|
||||
yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap_or(generated_text), start: first_start.unwrap(), queued: first_queued.unwrap() });
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -332,8 +332,8 @@ pub(crate) struct GenerateParameters {
|
|||
pub do_sample: bool,
|
||||
|
||||
/// Maximum number of tokens to generate.
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
#[schema(nullable = true, default = "100", example = "20")]
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "1024", example = "20")]
|
||||
pub max_new_tokens: Option<u32>,
|
||||
|
||||
/// Whether to prepend the prompt to the generated text
|
||||
|
@ -392,10 +392,6 @@ pub(crate) struct GenerateParameters {
|
|||
pub adapter_id: Option<String>,
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> Option<u32> {
|
||||
Some(100)
|
||||
}
|
||||
|
||||
fn default_parameters() -> GenerateParameters {
|
||||
GenerateParameters {
|
||||
best_of: None,
|
||||
|
@ -406,7 +402,7 @@ fn default_parameters() -> GenerateParameters {
|
|||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
max_new_tokens: default_max_new_tokens(),
|
||||
max_new_tokens: None,
|
||||
return_full_text: None,
|
||||
stop: Vec::new(),
|
||||
truncate: None,
|
||||
|
@ -464,7 +460,7 @@ pub struct CompletionRequest {
|
|||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
#[schema(default = "32")]
|
||||
#[schema(default = "1024", example = "32")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
|
||||
|
@ -842,7 +838,7 @@ pub(crate) struct ChatRequest {
|
|||
|
||||
/// The maximum number of tokens that can be generated in the chat completion.
|
||||
#[serde(default)]
|
||||
#[schema(example = "32")]
|
||||
#[schema(default = "1024", example = "32")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// UNUSED
|
||||
|
@ -937,7 +933,7 @@ impl ChatRequest {
|
|||
} = self;
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let max_new_tokens = max_tokens;
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
|
@ -1328,7 +1324,7 @@ pub struct SimpleToken {
|
|||
stop: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||
#[serde(rename_all(serialize = "snake_case"))]
|
||||
#[schema(example = "Length")]
|
||||
pub enum FinishReason {
|
||||
|
|
|
@ -714,7 +714,7 @@ pub(crate) async fn completions(
|
|||
..
|
||||
} = req;
|
||||
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let max_new_tokens = max_tokens;
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
/// Payload validation logic
|
||||
use crate::config::Config;
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{
|
||||
|
@ -12,6 +11,8 @@ use jsonschema::{Draft, JSONSchema};
|
|||
use outlines_core::json_schema::to_regex as json_schema_to_regex;
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
/// Payload validation logic
|
||||
use std::cmp::min;
|
||||
use std::io::Cursor;
|
||||
use std::iter;
|
||||
use std::sync::Arc;
|
||||
|
@ -21,6 +22,8 @@ use tokio::sync::oneshot;
|
|||
use tracing::{instrument, Span};
|
||||
use {once_cell::sync::Lazy, regex::Regex};
|
||||
|
||||
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
|
||||
|
||||
/// Validation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Validation {
|
||||
|
@ -131,7 +134,7 @@ impl Validation {
|
|||
add_special_tokens: bool,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
let (encoding, inputs) = self
|
||||
.tokenize(inputs.clone(), add_special_tokens, truncate)
|
||||
|
@ -144,10 +147,17 @@ impl Validation {
|
|||
};
|
||||
|
||||
// Get total tokens
|
||||
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
|
||||
max_new_tokens
|
||||
let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {
|
||||
(max_new_tokens, max_new_tokens)
|
||||
} else {
|
||||
self.max_total_tokens.saturating_sub(input_length) as u32
|
||||
// Use the maximum possible number of tokens as default
|
||||
// However, the system will re-queue the request everytime it completes
|
||||
// `DEFAULT_GENERATION_LENGTH` tokens.
|
||||
let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;
|
||||
(
|
||||
min(max_new_tokens, DEFAULT_GENERATION_LENGTH),
|
||||
max_new_tokens,
|
||||
)
|
||||
};
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
|
||||
|
@ -172,7 +182,13 @@ impl Validation {
|
|||
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
|
||||
|
||||
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
|
||||
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
|
||||
Ok((
|
||||
inputs,
|
||||
Some(input_ids),
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
max_total_new_tokens,
|
||||
))
|
||||
}
|
||||
|
||||
/// Validate a payload and get the number of tokens in the input
|
||||
|
@ -305,7 +321,7 @@ impl Validation {
|
|||
.unwrap_or(Ok(None))?;
|
||||
|
||||
// Validate inputs
|
||||
let (inputs, input_ids, input_length, max_new_tokens) = self
|
||||
let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self
|
||||
.validate_input(
|
||||
request.inputs,
|
||||
request.add_special_tokens,
|
||||
|
@ -381,6 +397,7 @@ impl Validation {
|
|||
};
|
||||
let stopping_parameters = ValidStoppingParameters {
|
||||
max_new_tokens,
|
||||
max_total_new_tokens,
|
||||
stop_sequences,
|
||||
ignore_eos_token: false,
|
||||
};
|
||||
|
@ -740,6 +757,8 @@ pub struct ValidParameters {
|
|||
pub struct ValidStoppingParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
pub max_new_tokens: u32,
|
||||
/// Maximum number of generated tokens before being re-queued by the system
|
||||
pub max_total_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
pub stop_sequences: Vec<String>,
|
||||
/// / Ignore end of sequence token
|
||||
|
|
Loading…
Reference in New Issue