parent
c0795de2f2
commit
e8bfe199ba
|
@ -56,12 +56,15 @@ pub(crate) struct GenerateParameters {
|
||||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||||
pub max_new_tokens: u32,
|
pub max_new_tokens: u32,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(default = "None", example = false)]
|
#[schema(default = "null", example = false)]
|
||||||
pub return_full_text: Option<bool>,
|
pub return_full_text: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
#[schema(default = "null", example = "null")]
|
||||||
|
pub truncate: Option<usize>,
|
||||||
|
#[serde(default)]
|
||||||
#[schema(default = "false", example = true)]
|
#[schema(default = "false", example = true)]
|
||||||
pub watermark: bool,
|
pub watermark: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -86,6 +89,7 @@ fn default_parameters() -> GenerateParameters {
|
||||||
max_new_tokens: default_max_new_tokens(),
|
max_new_tokens: default_max_new_tokens(),
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
|
truncate: None,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
details: false,
|
details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
|
|
@ -73,6 +73,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||||
max_new_tokens: 1,
|
max_new_tokens: 1,
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
|
truncate: None,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
details: false,
|
details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
|
|
@ -6,6 +6,7 @@ use rand::Rng;
|
||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokenizers::TruncationDirection;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
|
|
||||||
|
@ -157,6 +158,7 @@ fn validate(
|
||||||
do_sample,
|
do_sample,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
stop: stop_sequences,
|
stop: stop_sequences,
|
||||||
|
truncate,
|
||||||
seed,
|
seed,
|
||||||
watermark,
|
watermark,
|
||||||
..
|
..
|
||||||
|
@ -223,50 +225,70 @@ fn validate(
|
||||||
return Err(EmptyInput);
|
return Err(EmptyInput);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the number of tokens in the input
|
// Check if truncate is strictly positive and less than max_input_length
|
||||||
match tokenizer.encode(request.inputs.clone(), true) {
|
let truncate = truncate
|
||||||
Ok(encoding) => {
|
.map(|value| {
|
||||||
let input_length = encoding.len();
|
if value == 0 || value > max_input_length {
|
||||||
let total_tokens = input_length + max_new_tokens as usize;
|
return Err(ValidationError::Truncate(max_input_length, value));
|
||||||
|
|
||||||
if input_length > max_input_length {
|
|
||||||
Err(ValidationError::InputLength(max_input_length, input_length))
|
|
||||||
} else if total_tokens > max_total_tokens {
|
|
||||||
Err(ValidationError::MaxTotalTokens(
|
|
||||||
max_total_tokens,
|
|
||||||
input_length,
|
|
||||||
max_new_tokens,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
// Return ValidGenerateRequest
|
|
||||||
let parameters = NextTokenChooserParameters {
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
top_k,
|
|
||||||
top_p,
|
|
||||||
typical_p,
|
|
||||||
do_sample,
|
|
||||||
seed,
|
|
||||||
watermark,
|
|
||||||
};
|
|
||||||
let stopping_parameters = StoppingCriteriaParameters {
|
|
||||||
max_new_tokens,
|
|
||||||
stop_sequences,
|
|
||||||
};
|
|
||||||
|
|
||||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
|
||||||
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
|
|
||||||
|
|
||||||
Ok(ValidGenerateRequest {
|
|
||||||
inputs: request.inputs,
|
|
||||||
input_length: input_length as u32,
|
|
||||||
parameters,
|
|
||||||
stopping_parameters,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
Ok(Some(value))
|
||||||
Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
|
})
|
||||||
|
.unwrap_or(Ok(None))?;
|
||||||
|
|
||||||
|
// Get the number of tokens in the input
|
||||||
|
let mut encoding = tokenizer
|
||||||
|
.encode(request.inputs.clone(), true)
|
||||||
|
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||||
|
|
||||||
|
let (inputs, input_length) = if let Some(truncate) = truncate {
|
||||||
|
// truncate encoding and decode new inputs
|
||||||
|
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||||
|
let inputs = tokenizer
|
||||||
|
.decode(Vec::from(encoding.get_ids()), false)
|
||||||
|
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||||
|
(inputs, encoding.len())
|
||||||
|
} else {
|
||||||
|
(request.inputs, encoding.len())
|
||||||
|
};
|
||||||
|
|
||||||
|
if input_length > max_input_length {
|
||||||
|
return Err(ValidationError::InputLength(max_input_length, input_length));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let total_tokens = input_length + max_new_tokens as usize;
|
||||||
|
if total_tokens > max_total_tokens {
|
||||||
|
return Err(ValidationError::MaxTotalTokens(
|
||||||
|
max_total_tokens,
|
||||||
|
input_length,
|
||||||
|
max_new_tokens,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return ValidGenerateRequest
|
||||||
|
let parameters = NextTokenChooserParameters {
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
typical_p,
|
||||||
|
do_sample,
|
||||||
|
seed,
|
||||||
|
watermark,
|
||||||
|
};
|
||||||
|
let stopping_parameters = StoppingCriteriaParameters {
|
||||||
|
max_new_tokens,
|
||||||
|
stop_sequences,
|
||||||
|
};
|
||||||
|
|
||||||
|
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||||
|
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
|
||||||
|
|
||||||
|
Ok(ValidGenerateRequest {
|
||||||
|
inputs,
|
||||||
|
input_length: input_length as u32,
|
||||||
|
parameters,
|
||||||
|
stopping_parameters,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type ValidationRequest = (
|
type ValidationRequest = (
|
||||||
|
@ -293,6 +315,8 @@ pub enum ValidationError {
|
||||||
TopP,
|
TopP,
|
||||||
#[error("`top_k` must be strictly positive")]
|
#[error("`top_k` must be strictly positive")]
|
||||||
TopK,
|
TopK,
|
||||||
|
#[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
|
||||||
|
Truncate(usize, usize),
|
||||||
#[error("`typical_p` must be > 0.0 and < 1.0")]
|
#[error("`typical_p` must be > 0.0 and < 1.0")]
|
||||||
TypicalP,
|
TypicalP,
|
||||||
#[error("`max_new_tokens` must be strictly positive")]
|
#[error("`max_new_tokens` must be strictly positive")]
|
||||||
|
|
Loading…
Reference in New Issue