feat(router): support left truncation (#115)

closes #111
Authored by OlivierDehaene on 2023-03-09 13:10:30 +01:00 (committed via GitHub)
parent c0795de2f2
commit e8bfe199ba
3 changed files with 72 additions and 43 deletions
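
The new `truncate` parameter rides along in the `parameters` object of a generate request. A minimal client-side sketch, assuming a text-generation-inference server listening on localhost:8080 (hypothetical host and port) and the reqwest crate with its blocking and json features plus serde_json:

use serde_json::json;

fn main() -> Result<(), reqwest::Error> {
    // Hypothetical prompt; with `truncate` set, only its last 1000 tokens
    // are kept (left truncation) before the usual length checks run.
    let body = json!({
        "inputs": "a very long prompt ...",
        "parameters": {
            "truncate": 1000,
            "max_new_tokens": 20
        }
    });

    let response = reqwest::blocking::Client::new()
        .post("http://localhost:8080/generate")
        .json(&body)
        .send()?
        .text()?;
    println!("{response}");
    Ok(())
}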

router/src/lib.rs

@@ -56,12 +56,15 @@ pub(crate) struct GenerateParameters {
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
-    #[schema(default = "None", example = false)]
+    #[schema(default = "null", example = false)]
     pub return_full_text: Option<bool>,
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "null", example = "null")]
+    pub truncate: Option<usize>,
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,
     #[serde(default)]
@@ -86,6 +89,7 @@ fn default_parameters() -> GenerateParameters {
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
+        truncate: None,
         watermark: false,
         details: false,
         seed: None,
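
Because the field is `#[serde(default)]` and optional, payloads from existing clients still deserialize with `truncate: None`. A minimal sketch of that behavior, using a trimmed-down, hypothetical stand-in for `GenerateParameters` (assumes serde and serde_json):

use serde::Deserialize;

#[derive(Deserialize, Debug, PartialEq)]
struct Params {
    #[serde(default)]
    truncate: Option<usize>,
    #[serde(default)]
    watermark: bool,
}

fn main() {
    // Old-style payload without `truncate`: defaults to None, no truncation.
    let old: Params = serde_json::from_str(r#"{"watermark": true}"#).unwrap();
    assert_eq!(old.truncate, None);

    // New payload opting in to left truncation at 1000 tokens.
    let new: Params = serde_json::from_str(r#"{"truncate": 1000}"#).unwrap();
    assert_eq!(new.truncate, Some(1000));
}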

router/src/server.rs

@@ -73,6 +73,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         max_new_tokens: 1,
         return_full_text: None,
         stop: Vec::new(),
+        truncate: None,
         watermark: false,
         details: false,
         seed: None,

router/src/validation.rs

@@ -6,6 +6,7 @@ use rand::Rng;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
+use tokenizers::TruncationDirection;
 use tokio::sync::{mpsc, oneshot};
 use tracing::{instrument, Span};
@@ -157,6 +158,7 @@ fn validate(
         do_sample,
         max_new_tokens,
         stop: stop_sequences,
+        truncate,
         seed,
         watermark,
         ..
@@ -223,21 +225,45 @@ fn validate(
         return Err(EmptyInput);
     }
 
+    // Check if truncate is strictly positive and less than max_input_length
+    let truncate = truncate
+        .map(|value| {
+            if value == 0 || value > max_input_length {
+                return Err(ValidationError::Truncate(max_input_length, value));
+            }
+            Ok(Some(value))
+        })
+        .unwrap_or(Ok(None))?;
+
     // Get the number of tokens in the input
-    match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(encoding) => {
-            let input_length = encoding.len();
-            let total_tokens = input_length + max_new_tokens as usize;
+    let mut encoding = tokenizer
+        .encode(request.inputs.clone(), true)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+
+    let (inputs, input_length) = if let Some(truncate) = truncate {
+        // truncate encoding and decode new inputs
+        encoding.truncate(truncate, 0, TruncationDirection::Left);
+        let inputs = tokenizer
+            .decode(Vec::from(encoding.get_ids()), false)
+            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+        (inputs, encoding.len())
+    } else {
+        (request.inputs, encoding.len())
+    };
 
     if input_length > max_input_length {
-        Err(ValidationError::InputLength(max_input_length, input_length))
-    } else if total_tokens > max_total_tokens {
-        Err(ValidationError::MaxTotalTokens(
+        return Err(ValidationError::InputLength(max_input_length, input_length));
+    }
+
+    let total_tokens = input_length + max_new_tokens as usize;
+    if total_tokens > max_total_tokens {
+        return Err(ValidationError::MaxTotalTokens(
             max_total_tokens,
             input_length,
             max_new_tokens,
-        ))
-    } else {
+        ));
+    }
+
     // Return ValidGenerateRequest
     let parameters = NextTokenChooserParameters {
         temperature,
@@ -258,15 +284,11 @@ fn validate(
     metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
 
     Ok(ValidGenerateRequest {
-        inputs: request.inputs,
+        inputs,
         input_length: input_length as u32,
         parameters,
         stopping_parameters,
     })
-            }
-        }
-        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
-    }
 }
 
 type ValidationRequest = (
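
The heart of the change is `Encoding::truncate` with `TruncationDirection::Left`, followed by decoding the surviving ids so the stored input string matches what the model will actually see. A standalone sketch of that step, mirroring the calls in the hunks above (the `tokenizer.json` path and the prompt are hypothetical; assumes the tokenizers crate):

use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;

fn truncate_left(
    tokenizer: &Tokenizer,
    input: &str,
    truncate: usize,
) -> Result<(String, usize), Box<dyn std::error::Error>> {
    let mut encoding = tokenizer.encode(input, true)?;
    // Drop tokens from the left so only the last `truncate` ids survive.
    encoding.truncate(truncate, 0, TruncationDirection::Left);
    // Decode the kept ids back into a string that replaces the raw input.
    let inputs = tokenizer.decode(Vec::from(encoding.get_ids()), false)?;
    Ok((inputs, encoding.len()))
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let (inputs, kept) = truncate_left(&tokenizer, "a very long prompt ...", 8)?;
    println!("kept {kept} tokens: {inputs}");
    Ok(())
}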
@@ -293,6 +315,8 @@ pub enum ValidationError {
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
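
For reference, the new variant renders through thiserror's derive exactly as its format string reads. A minimal sketch of the message it produces (assumes the thiserror crate, with the enum trimmed to the new variant):

use thiserror::Error;

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
}

fn main() {
    // e.g. max_input_length = 1024, requested truncate = 2048
    let err = ValidationError::Truncate(1024, 2048);
    println!("{err}");
    // Prints: `truncate` must be strictly positive and less than 1024. Given: 2048
}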