feat: auto max_new_tokens (#2803)

* feat: auto max_new_tokens

* update default

* Fixing the tests.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Authored by OlivierDehaene on 2024-12-06 05:50:35 +01:00 (committed via GitHub)
parent 6685e8fcda
commit 8c3669b287
7 changed files with 100 additions and 23 deletions
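
In short: this commit makes `max_new_tokens` (and the OpenAI-style `max_tokens`) optional end to end. When a request omits the field, the router budgets up to 1024 new tokens per scheduling pass, bounded by the remaining context window, and transparently re-queues the request until either the model stops on its own or the overall `max_total_new_tokens` budget is exhausted. A minimal client-side sketch of the user-visible effect; the `/generate` path, the prompt text, and the use of `serde_json` are illustrative assumptions rather than part of the diff:

```rust
use serde_json::json;

fn main() {
    // Hypothetical request body for TGI's text-generation route.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": {
            // `max_new_tokens` is deliberately omitted: after this commit the
            // router derives a budget (up to 1024 tokens per pass, capped by
            // the space left in the context window) instead of defaulting to 100.
            "do_sample": false
        }
    });
    println!("POST /generate with body: {body}");
}
```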

View File

@@ -436,6 +436,7 @@ mod tests {
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,
max_new_tokens: 1,
max_total_new_tokens: 1024,
stop_sequences: vec![],
},
top_n_tokens: 0,

View File

@@ -573,6 +573,7 @@ mod tests {
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,
max_new_tokens: 1,
max_total_new_tokens: 1024,
stop_sequences: vec![],
},
top_n_tokens: 0,

View File

@@ -1013,6 +1013,7 @@
"type": "integer",
"format": "int32",
"description": "The maximum number of tokens that can be generated in the chat completion.",
"default": "1024",
"example": "32",
"nullable": true,
"minimum": 0
@@ -1329,7 +1330,8 @@
"type": "integer",
"format": "int32",
"description": "The maximum number of tokens that can be generated in the chat completion.",
"default": "32",
"default": "1024",
"example": "32",
"nullable": true,
"minimum": 0
},
@@ -1591,7 +1593,7 @@
"type": "integer",
"format": "int32",
"description": "Maximum number of tokens to generate.",
"default": "100",
"default": "1024",
"example": "20",
"nullable": true,
"minimum": 0

View File

@@ -111,21 +111,79 @@ impl Infer {
})?;
// Validate request
let mut local_request = request.clone();
let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
err
})?;
let seed = valid_request.parameters.seed;
local_request.parameters.seed = Some(seed);
let input_length = valid_request.input_length;
let max_total_new_tokens = valid_request.stopping_parameters.max_total_new_tokens;
let mut generation_stream = self.backend.schedule(valid_request)?;
// Wrap generation stream to update the backend health if the stream contains an error
let final_stream = stream! {
let mut total_generated_tokens = 0;
let mut first_start = None;
let mut first_queued = None;
let mut all_generated_text: Option<GeneratedText> = None;
while let Some(response) = generation_stream.next().await {
yield response.inspect_err(|_err| {
let response = response.inspect_err(|_err| {
self.backend_health.store(false, Ordering::SeqCst);
})
})?;
match response {
InferStreamResponse::Prefill(_) => yield Ok(response),
InferStreamResponse::Intermediate { .. } => {
total_generated_tokens += 1;
yield Ok(response);
}
InferStreamResponse::End { token, top_tokens, generated_text, start, queued } => {
total_generated_tokens += 1;
first_start = first_start.or(Some(start));
first_queued = first_queued.or(Some(queued));
if let Some(v) = all_generated_text.as_mut() {
v.text.push_str(&generated_text.text);
v.generated_tokens = total_generated_tokens;
v.finish_reason = generated_text.finish_reason.clone();
};
if matches!(generated_text.finish_reason, FinishReason::Length) && total_generated_tokens < max_total_new_tokens {
local_request.inputs.push_str(&generated_text.text);
all_generated_text = all_generated_text.or(Some(generated_text));
let valid_request = match self.validation.validate(local_request.clone()).await {
Ok(valid_request) => valid_request,
Err(err) => {
tracing::debug!("Failed to continue request: {err}");
yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
break;
}
};
generation_stream = match self.backend.schedule(valid_request) {
Ok(stream) => {
tracing::debug!("Continue request");
yield Ok(InferStreamResponse::Intermediate { token, top_tokens } );
stream
},
Err(err) => {
tracing::debug!("Failed to continue request: {err}");
yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });
break;
}
}
} else {
yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap_or(generated_text), start: first_start.unwrap(), queued: first_queued.unwrap() });
break;
}
}
}
}
};
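
The wrapped stream above is the core of the feature: when a pass ends with `FinishReason::Length` and fewer than `max_total_new_tokens` tokens have been produced overall, the router appends the generated text to the input, re-validates the request, schedules another pass, and stitches the partial `GeneratedText` results together; the `End` event of an intermediate pass is downgraded to an `Intermediate` event so the client sees one continuous stream. A simplified, synchronous sketch of that control flow; `Chunk`, `schedule_chunk`, and `generate` are hypothetical stand-ins for the stream types and `backend.schedule`, not the actual API:

```rust
// One generation pass and what the continuation loop needs to know about it.
struct Chunk {
    text: String,
    tokens: u32,
    hit_length_limit: bool, // true when the pass stopped on FinishReason::Length
}

// Stand-in for scheduling a pass on the backend and draining its stream.
fn schedule_chunk(_input: &str, _per_pass_budget: u32) -> Chunk {
    Chunk { text: " ...".to_string(), tokens: 1024, hit_length_limit: true }
}

fn generate(prompt: &str, per_pass_budget: u32, max_total_new_tokens: u32) -> String {
    let mut input = prompt.to_string();
    let mut output = String::new();
    let mut total = 0u32;
    loop {
        let chunk = schedule_chunk(&input, per_pass_budget);
        total += chunk.tokens;
        output.push_str(&chunk.text);
        // Continue only if the pass was cut off by the per-pass length cap and
        // the overall budget is not exhausted; otherwise return what we have.
        if chunk.hit_length_limit && total < max_total_new_tokens {
            input.push_str(&chunk.text); // re-queue with the new text appended
        } else {
            return output;
        }
    }
}

fn main() {
    let out = generate("Once upon a time", 1024, 4000);
    println!("stitched {} characters across several passes", out.len());
}
```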

View File

@@ -332,8 +332,8 @@ pub(crate) struct GenerateParameters {
pub do_sample: bool,
/// Maximum number of tokens to generate.
#[serde(default = "default_max_new_tokens")]
#[schema(nullable = true, default = "100", example = "20")]
#[serde(default)]
#[schema(nullable = true, default = "1024", example = "20")]
pub max_new_tokens: Option<u32>,
/// Whether to prepend the prompt to the generated text
@@ -392,10 +392,6 @@ pub(crate) struct GenerateParameters {
pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
Some(100)
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
best_of: None,
@@ -406,7 +402,7 @@ fn default_parameters() -> GenerateParameters {
top_p: None,
typical_p: None,
do_sample: true,
max_new_tokens: default_max_new_tokens(),
max_new_tokens: None,
return_full_text: None,
stop: Vec::new(),
truncate: None,
@@ -464,7 +460,7 @@ pub struct CompletionRequest {
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(default = "32")]
#[schema(default = "1024", example = "32")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
@@ -842,7 +838,7 @@ pub(crate) struct ChatRequest {
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
#[schema(default = "1024", example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
@@ -937,7 +933,7 @@ impl ChatRequest {
} = self;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let max_new_tokens = max_tokens;
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
@@ -1328,7 +1324,7 @@ pub struct SimpleToken {
stop: usize,
}
#[derive(Debug, Serialize, ToSchema)]
#[derive(Debug, Serialize, ToSchema, Clone)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub enum FinishReason {
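
At the API layer, `max_new_tokens` drops its hard-coded `Some(100)` default in favour of a plain `#[serde(default)]` on an `Option<u32>` (and `max_tokens` on the chat/completions requests loses its `.or(Some(100))` fallback), so an omitted field reaches validation as `None`. A small sketch of just the deserialization behaviour; `Params` is a stripped-down stand-in for `GenerateParameters`, not the real struct:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Params {
    // With `#[serde(default)]` on an Option, a missing field becomes None;
    // the old `default_max_new_tokens()` helper that forced Some(100) is gone.
    #[serde(default)]
    max_new_tokens: Option<u32>,
}

fn main() {
    let omitted: Params = serde_json::from_str("{}").unwrap();
    assert_eq!(omitted.max_new_tokens, None); // validation fills in the default later

    let explicit: Params = serde_json::from_str(r#"{"max_new_tokens": 32}"#).unwrap();
    assert_eq!(explicit.max_new_tokens, Some(32));
}
```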

View File

@@ -714,7 +714,7 @@ pub(crate) async fn completions(
..
} = req;
let max_new_tokens = max_tokens.or(Some(100));
let max_new_tokens = max_tokens;
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {

View File

@@ -1,4 +1,3 @@
/// Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{
@@ -12,6 +11,8 @@ use jsonschema::{Draft, JSONSchema};
use outlines_core::json_schema::to_regex as json_schema_to_regex;
use rand::{thread_rng, Rng};
use serde_json::Value;
/// Payload validation logic
use std::cmp::min;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;
@@ -21,6 +22,8 @@ use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
@@ -131,7 +134,7 @@ impl Validation {
add_special_tokens: bool,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {
// If we have a fast tokenizer
let (encoding, inputs) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
@@ -144,10 +147,17 @@ impl Validation {
};
// Get total tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
max_new_tokens
let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {
(max_new_tokens, max_new_tokens)
} else {
self.max_total_tokens.saturating_sub(input_length) as u32
// Use the maximum possible number of tokens as default
// However, the system will re-queue the request every time it completes
// `DEFAULT_GENERATION_LENGTH` tokens.
let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;
(
min(max_new_tokens, DEFAULT_GENERATION_LENGTH),
max_new_tokens,
)
};
let total_tokens = input_length + max_new_tokens as usize;
@@ -172,7 +182,13 @@ impl Validation {
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
Ok((
inputs,
Some(input_ids),
input_length,
max_new_tokens,
max_total_new_tokens,
))
}
/// Validate a payload and get the number of tokens in the input
@@ -305,7 +321,7 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
let (inputs, input_ids, input_length, max_new_tokens) = self
let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self
.validate_input(
request.inputs,
request.add_special_tokens,
@@ -381,6 +397,7 @@ impl Validation {
};
let stopping_parameters = ValidStoppingParameters {
max_new_tokens,
max_total_new_tokens,
stop_sequences,
ignore_eos_token: false,
};
@@ -740,6 +757,8 @@ pub struct ValidParameters {
pub struct ValidStoppingParameters {
/// / Maximum number of generated tokens
pub max_new_tokens: u32,
/// Maximum number of generated tokens before being re-queued by the system
pub max_total_new_tokens: u32,
/// / Optional stopping sequences
pub stop_sequences: Vec<String>,
/// / Ignore end of sequence token
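
Validation is where a missing value is finally resolved: `validate_input` now returns both a per-pass budget (`max_new_tokens`, capped at `DEFAULT_GENERATION_LENGTH` = 1024) and the overall budget (`max_total_new_tokens`, everything left in the context window), and the latter is carried on `ValidStoppingParameters` so the scheduler knows when to stop re-queuing. A hedged sketch of that arithmetic; `default_budgets` is a made-up free function, the real logic sits inline in `validate_input`:

```rust
use std::cmp::min;

const DEFAULT_GENERATION_LENGTH: u32 = 1024;

/// Returns (max_new_tokens, max_total_new_tokens) for one request.
fn default_budgets(
    requested: Option<u32>,
    max_total_tokens: usize,
    input_length: usize,
) -> (u32, u32) {
    match requested {
        // An explicit request: per-pass and total budgets coincide.
        Some(n) => (n, n),
        None => {
            // Default to everything left in the context window as the total
            // budget, but schedule at most 1024 tokens per pass so the request
            // is re-queued instead of holding a batch slot for a very long run.
            let remaining = max_total_tokens.saturating_sub(input_length) as u32;
            (min(remaining, DEFAULT_GENERATION_LENGTH), remaining)
        }
    }
}

fn main() {
    // e.g. a 4096-token window with a 96-token prompt:
    assert_eq!(default_budgets(None, 4096, 96), (1024, 4000));
    assert_eq!(default_budgets(Some(32), 4096, 96), (32, 32));
}
```

Keeping the per-pass budget small is what makes the "auto" default safe: a request with a huge remaining context does not monopolize a batch slot, it simply gets re-queued every 1024 generated tokens until its total budget runs out or generation finishes naturally.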