diff --git a/router/src/validation.rs b/router/src/validation.rs index 9adedc5b..d0ea137d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -116,12 +116,16 @@ impl Validation { // In this case, we don't know the real length in tokens of the inputs // However, the inputs will be truncated by the python servers // We make sure that truncate + max_new_tokens <= self.max_total_tokens - let input_length = truncate.unwrap_or(self.max_input_length); let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { max_new_tokens } else { - self.max_total_tokens.saturating_sub(input_length) as u32 + if let Some(truncate) = truncate { + self.max_total_tokens.saturating_sub(truncate) as u32 + } else { + return Err(ValidationError::UnsetMaxNewTokens) + } }; + let input_length = truncate.unwrap_or(self.max_input_length); // Validate MaxNewTokens if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { @@ -393,6 +397,8 @@ pub enum ValidationError { Truncate(usize, usize), #[error("`typical_p` must be > 0.0 and < 1.0")] TypicalP, + #[error("one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use")] + UnsetMaxNewTokens, #[error("`max_new_tokens` must be strictly positive")] NegativeMaxNewTokens, #[error("`max_new_tokens` must be <= {0}. Given: {1}")]