Fix tests.

This commit is contained in:
Nicolas Patry 2024-02-08 15:25:34 +00:00
parent 40f693b6b9
commit 81fa53f37b
1 changed files with 11 additions and 1 deletions

View File

@ -107,7 +107,7 @@ impl Validation {
) -> Result<(String, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
if self.batch_dimension {
if !self.batch_dimension {
let input_length = encoding.len();
// Get total tokens
@ -465,6 +465,7 @@ mod tests {
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let batch_dimension = false;
let validation = Validation::new(
workers,
tokenizer,
@ -473,6 +474,7 @@ mod tests {
max_top_n_tokens,
max_input_length,
max_total_tokens,
batch_dimension
);
let max_new_tokens = 10;
@ -494,6 +496,7 @@ mod tests {
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let batch_dimension = false;
let validation = Validation::new(
workers,
tokenizer,
@ -502,6 +505,7 @@ mod tests {
max_top_n_tokens,
max_input_length,
max_total_tokens,
batch_dimension,
);
let max_new_tokens = 10;
@ -523,6 +527,7 @@ mod tests {
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let batch_dimension = false;
let validation = Validation::new(
workers,
tokenizer,
@ -531,6 +536,7 @@ mod tests {
max_top_n_tokens,
max_input_length,
max_total_tokens,
batch_dimension
);
match validation
.validate(GenerateRequest {
@ -557,6 +563,7 @@ mod tests {
let max_input_length = 5;
let max_total_tokens = 106;
let workers = 1;
let batch_dimension = false;
let validation = Validation::new(
workers,
tokenizer,
@ -565,6 +572,7 @@ mod tests {
max_top_n_tokens,
max_input_length,
max_total_tokens,
batch_dimension
);
match validation
.validate(GenerateRequest {
@ -620,6 +628,7 @@ mod tests {
let max_input_length = 5;
let max_total_tokens = 106;
let workers = 1;
let batch_dimension = false;
let validation = Validation::new(
workers,
tokenizer,
@ -628,6 +637,7 @@ mod tests {
max_top_n_tokens,
max_input_length,
max_total_tokens,
batch_dimension,
);
match validation
.validate(GenerateRequest {