From e8bfe199bacb2778153bf0029b20dab866733c75 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 9 Mar 2023 13:10:30 +0100
Subject: [PATCH] feat(router): support left truncation (#115)

closes #111
---
 router/src/lib.rs        |   6 ++-
 router/src/server.rs     |   1 +
 router/src/validation.rs | 108 ++++++++++++++++++++++++---------------
 3 files changed, 72 insertions(+), 43 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index d375eafb..9fcc5085 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -56,12 +56,15 @@ pub(crate) struct GenerateParameters {
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
-    #[schema(default = "None", example = false)]
+    #[schema(default = "null", example = false)]
     pub return_full_text: Option<bool>,
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "null", example = "null")]
+    pub truncate: Option<usize>,
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,
     #[serde(default)]
@@ -86,6 +89,7 @@ fn default_parameters() -> GenerateParameters {
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
+        truncate: None,
         watermark: false,
         details: false,
         seed: None,
diff --git a/router/src/server.rs b/router/src/server.rs
index 2ce5699d..ef10b7b1 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -73,6 +73,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
                 max_new_tokens: 1,
                 return_full_text: None,
                 stop: Vec::new(),
+                truncate: None,
                 watermark: false,
                 details: false,
                 seed: None,
diff --git a/router/src/validation.rs b/router/src/validation.rs
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
-    match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(encoding) => {
-            let input_length = encoding.len();
-            let total_tokens = input_length + max_new_tokens as usize;
-
-            if input_length > max_input_length {
-                Err(ValidationError::InputLength(max_input_length, input_length))
-            } else if total_tokens > max_total_tokens {
-                Err(ValidationError::MaxTotalTokens(
-                    max_total_tokens,
-                    input_length,
-                    max_new_tokens,
-                ))
-            } else {
-                // Return ValidGenerateRequest
-                let parameters = NextTokenChooserParameters {
-                    temperature,
-                    repetition_penalty,
-                    top_k,
-                    top_p,
-                    typical_p,
-                    do_sample,
-                    seed,
-                    watermark,
-                };
-                let stopping_parameters = StoppingCriteriaParameters {
-                    max_new_tokens,
-                    stop_sequences,
-                };
-
-                metrics::histogram!("tgi_request_input_length", input_length as f64);
-                metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
-
-                Ok(ValidGenerateRequest {
-                    inputs: request.inputs,
-                    input_length: input_length as u32,
-                    parameters,
-                    stopping_parameters,
-                })
-            }
-        }
-        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
-    }
+    // Check if truncate is strictly positive and less than max_input_length
+    let truncate = truncate
+        .map(|value| {
+            if value == 0 || value > max_input_length {
+                return Err(ValidationError::Truncate(max_input_length, value));
+            }
+            Ok(Some(value))
+        })
+        .unwrap_or(Ok(None))?;
+
+    // Get the number of tokens in the input
+    let mut encoding = tokenizer
+        .encode(request.inputs.clone(), true)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+
+    let (inputs, input_length) = if let Some(truncate) = truncate {
+        // truncate encoding and decode new inputs
+        encoding.truncate(truncate, 0, TruncationDirection::Left);
+        let inputs = tokenizer
+            .decode(Vec::from(encoding.get_ids()), false)
+            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+        (inputs, encoding.len())
+    } else {
+        (request.inputs, encoding.len())
+    };
+
+    if input_length > max_input_length {
+        return Err(ValidationError::InputLength(max_input_length, input_length));
+    }
+
+    let total_tokens = input_length + max_new_tokens as usize;
+    if total_tokens > max_total_tokens {
+        return Err(ValidationError::MaxTotalTokens(
+            max_total_tokens,
+            input_length,
+            max_new_tokens,
+        ));
+    }
+
+    // Return ValidGenerateRequest
+    let parameters = NextTokenChooserParameters {
+        temperature,
+        repetition_penalty,
+        top_k,
+        top_p,
+        typical_p,
+        do_sample,
+        seed,
+        watermark,
+    };
+    let stopping_parameters = StoppingCriteriaParameters {
+        max_new_tokens,
+        stop_sequences,
+    };
+
+    metrics::histogram!("tgi_request_input_length", input_length as f64);
+    metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+
+    Ok(ValidGenerateRequest {
+        inputs,
+        input_length: input_length as u32,
+        parameters,
+        stopping_parameters,
+    })
 }

 type ValidationRequest = (
@@ -293,6 +315,8 @@ pub enum ValidationError {
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
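
Note for reviewers: the truncation path above leans entirely on the `tokenizers` crate's
`Encoding::truncate` with `TruncationDirection::Left`. The snippet below is a minimal,
standalone sketch of that primitive, not part of the patch; it assumes a local
`tokenizer.json` and the same `tokenizers` version the router pins, and only mirrors the
encode / truncate / decode sequence the router now performs when `truncate` is set.

    // Standalone sketch of the left-truncation behaviour used by this patch.
    // The `tokenizer.json` path is an illustration; the router loads the
    // model's own tokenizer instead.
    use tokenizers::{Tokenizer, TruncationDirection};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;

        let inputs = "a very long prompt that should be cut from the left";
        let truncate = 8; // keep only the last 8 tokens, like `truncate: 8` in a request

        let mut encoding = tokenizer.encode(inputs, true)?;

        // `TruncationDirection::Left` drops tokens from the start of the
        // sequence, so the most recent part of the prompt survives.
        encoding.truncate(truncate, 0, TruncationDirection::Left);

        // Decode the surviving ids back to text, as the router does to rebuild
        // the `inputs` string that is forwarded to the model.
        let truncated_inputs = tokenizer.decode(Vec::from(encoding.get_ids()), false)?;

        println!("{} tokens kept: {}", encoding.len(), truncated_inputs);
        Ok(())
    }

Keeping the last `truncate` tokens rather than the first is what makes this useful for
chat-style prompts, where the most recent turns matter most.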
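
From the client side, the new field rides along in `parameters` like any other knob.
A hypothetical request using it is sketched below; the crate choices (`reqwest` with the
`blocking` and `json` features, plus `serde_json`), the host, and the port are assumptions
for illustration only. The JSON shape (`inputs` plus `parameters.truncate`) is the part
that comes from this patch.

    // Hypothetical client call exercising the new `truncate` parameter.
    use serde_json::json;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let body = json!({
            "inputs": "a prompt that may be longer than max_input_length",
            "parameters": {
                // Keep only the last 1000 prompt tokens before validation runs.
                "truncate": 1000,
                "max_new_tokens": 20
            }
        });

        // Adjust the address to your deployment; localhost:3000 is an assumption.
        let response = reqwest::blocking::Client::new()
            .post("http://localhost:3000/generate")
            .json(&body)
            .send()?
            .text()?;

        println!("{response}");
        Ok(())
    }

With `truncate` unset the behaviour is unchanged and over-long prompts are still rejected
with `InputLength`; with it set, the prompt is cut from the left before those checks run.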