fix: force one of max_new_tokens or truncate with slow tokenizer
This commit is contained in:
parent
dd304cf14c
commit
20ee71dcf5
|
@ -116,12 +116,16 @@ impl Validation {
|
||||||
// In this case, we don't know the real length in tokens of the inputs
|
// In this case, we don't know the real length in tokens of the inputs
|
||||||
// However, the inputs will be truncated by the python servers
|
// However, the inputs will be truncated by the python servers
|
||||||
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
|
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
|
||||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
|
||||||
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
|
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
|
||||||
max_new_tokens
|
max_new_tokens
|
||||||
} else {
|
} else {
|
||||||
self.max_total_tokens.saturating_sub(input_length) as u32
|
if let Some(truncate) = truncate {
|
||||||
|
self.max_total_tokens.saturating_sub(truncate) as u32
|
||||||
|
} else {
|
||||||
|
return Err(ValidationError::UnsetMaxNewTokens)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||||
|
|
||||||
// Validate MaxNewTokens
|
// Validate MaxNewTokens
|
||||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||||
|
@ -393,6 +397,8 @@ pub enum ValidationError {
|
||||||
Truncate(usize, usize),
|
Truncate(usize, usize),
|
||||||
#[error("`typical_p` must be > 0.0 and < 1.0")]
|
#[error("`typical_p` must be > 0.0 and < 1.0")]
|
||||||
TypicalP,
|
TypicalP,
|
||||||
|
#[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")]
|
||||||
|
UnsetMaxNewTokens,
|
||||||
#[error("`max_new_tokens` must be strictly positive")]
|
#[error("`max_new_tokens` must be strictly positive")]
|
||||||
NegativeMaxNewTokens,
|
NegativeMaxNewTokens,
|
||||||
#[error("`max_new_tokens` must be <= {0}. Given: {1}")]
|
#[error("`max_new_tokens` must be <= {0}. Given: {1}")]
|
||||||
|
|
Loading…
Reference in New Issue