fix: force one of max_new_tokens or truncate with slow tokenizer

OlivierDehaene 2023-10-11 10:46:40 +02:00
parent dd304cf14c
commit 20ee71dcf5
1 changed file with 8 additions and 2 deletions

@@ -116,12 +116,16 @@ impl Validation {
         // In this case, we don't know the real length in tokens of the inputs
         // However, the inputs will be truncated by the python servers
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
-        let input_length = truncate.unwrap_or(self.max_input_length);
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
         } else {
-            self.max_total_tokens.saturating_sub(input_length) as u32
+            if let Some(truncate) = truncate {
+                self.max_total_tokens.saturating_sub(truncate) as u32
+            } else {
+                return Err(ValidationError::UnsetMaxNewTokens)
+            }
         };
+        let input_length = truncate.unwrap_or(self.max_input_length);
 
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
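
With a slow tokenizer the router cannot measure the prompt, so it can only budget generation against an explicit `truncate`, and the hunk above makes it reject requests that set neither parameter. A minimal standalone sketch of that fallback order (the function name, error enum, and the `max_total_tokens = 2048` figure are illustrative, not taken from the repository):

// Sketch only: mirrors the branch order of the patch outside the real
// Validation struct. All names and values here are illustrative.
#[derive(Debug, PartialEq)]
enum SketchError {
    UnsetMaxNewTokens,
}

fn resolve_max_new_tokens(
    max_total_tokens: usize,
    max_new_tokens: Option<u32>,
    truncate: Option<usize>,
) -> Result<u32, SketchError> {
    if let Some(max_new_tokens) = max_new_tokens {
        // Caller gave an explicit budget: use it as-is.
        Ok(max_new_tokens)
    } else if let Some(truncate) = truncate {
        // No explicit budget, but the prompt will be cut to `truncate` tokens,
        // so the rest of the total budget can go to generation.
        Ok(max_total_tokens.saturating_sub(truncate) as u32)
    } else {
        // Slow tokenizer and neither parameter set: no safe default exists.
        Err(SketchError::UnsetMaxNewTokens)
    }
}

fn main() {
    assert_eq!(resolve_max_new_tokens(2048, Some(64), None), Ok(64));
    assert_eq!(resolve_max_new_tokens(2048, None, Some(1536)), Ok(512));
    assert_eq!(
        resolve_max_new_tokens(2048, None, None),
        Err(SketchError::UnsetMaxNewTokens)
    );
}
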
@@ -393,6 +397,8 @@ pub enum ValidationError {
     Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
+    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
+    UnsetMaxNewTokens,
     #[error("`max_new_tokens` must be strictly positive")]
     NegativeMaxNewTokens,
     #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
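
The new variant's user-facing text comes from its `#[error(...)]` attribute, which (assuming the `thiserror` derive that these attributes point to) becomes the variant's `Display` string. A stripped-down sketch, not the repository's full `ValidationError`:

use thiserror::Error;

// Stand-in enum containing only the newly added variant.
#[derive(Debug, Error)]
enum ValidationErrorSketch {
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
}

fn main() {
    // `thiserror` derives `Display` from the attribute string.
    println!("{}", ValidationErrorSketch::UnsetMaxNewTokens);
    // prints: one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use
}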