fix: force one of max_new_tokens or truncate with slow tokenizer
parent dd304cf14c
commit 20ee71dcf5
@@ -116,12 +116,16 @@ impl Validation {
             // In this case, we don't know the real length in tokens of the inputs
             // However, the inputs will be truncated by the python servers
             // We make sure that truncate + max_new_tokens <= self.max_total_tokens
-            let input_length = truncate.unwrap_or(self.max_input_length);
             let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                 max_new_tokens
             } else {
-                self.max_total_tokens.saturating_sub(input_length) as u32
+                if let Some(truncate) = truncate {
+                    self.max_total_tokens.saturating_sub(truncate) as u32
+                } else {
+                    return Err(ValidationError::UnsetMaxNewTokens)
+                }
             };
+            let input_length = truncate.unwrap_or(self.max_input_length);
 
             // Validate MaxNewTokens
             if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
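Taken on its own, the new fallback reduces to the standalone sketch below. The free function resolve_max_new_tokens and the main driver are hypothetical illustrations; only the field names and the UnsetMaxNewTokens variant come from the diff above.

// Minimal sketch of the fallback introduced by this commit, pulled out of
// the Validation impl. The free function is hypothetical; names mirror the diff.
#[derive(Debug, PartialEq)]
enum ValidationError {
    UnsetMaxNewTokens,
}

fn resolve_max_new_tokens(
    max_new_tokens: Option<u32>,
    truncate: Option<usize>,
    max_total_tokens: usize,
) -> Result<u32, ValidationError> {
    if let Some(max_new_tokens) = max_new_tokens {
        // An explicit max_new_tokens always wins.
        Ok(max_new_tokens)
    } else if let Some(truncate) = truncate {
        // With a slow tokenizer the only known bound on the input is the
        // truncation length, so budget the rest of the total token window.
        Ok(max_total_tokens.saturating_sub(truncate) as u32)
    } else {
        // Neither bound is known: refuse instead of guessing.
        Err(ValidationError::UnsetMaxNewTokens)
    }
}

fn main() {
    assert_eq!(resolve_max_new_tokens(Some(20), None, 2048), Ok(20));
    assert_eq!(resolve_max_new_tokens(None, Some(1000), 2048), Ok(1048));
    assert_eq!(
        resolve_max_new_tokens(None, None, 2048),
        Err(ValidationError::UnsetMaxNewTokens)
    );
}

Before this change, a request that set neither max_new_tokens nor truncate fell back to max_total_tokens - max_input_length even though the real input length was unknown; the commit turns that guess into an explicit validation error.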
@@ -393,6 +397,8 @@ pub enum ValidationError {
     Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
+    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
+    UnsetMaxNewTokens,
     #[error("`max_new_tokens` must be strictly positive")]
     NegativeMaxNewTokens,
     #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
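The #[error] attributes suggest the enum derives thiserror::Error; assuming that, the sketch below shows how the new variant renders as a user-facing message (requires the thiserror crate).

// Sketch of the new variant's Display output, assuming a thiserror derive
// as the #[error] attributes above suggest.
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
    UnsetMaxNewTokens,
}

fn main() {
    let err = ValidationError::UnsetMaxNewTokens;
    // Prints: one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use
    println!("{err}");
}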