This commit is contained in:
OlivierDehaene 2023-06-23 15:01:05 +02:00
parent 83e442ca9a
commit a4fd6905d8
1 changed files with 12 additions and 10 deletions

View File

@ -154,22 +154,24 @@ impl Infer {
// Validation // Validation
// logprobs, ids and texts must have the same lengths // logprobs, ids and texts must have the same lengths
if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length { if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length
return Err(InferError::GenerationError(format!("Prefill tokens do not have the correct lengths"))) {
return Err(InferError::GenerationError(
"Prefill tokens do not have the correct lengths".to_string(),
));
} }
result_prefill = Vec::with_capacity(tokens_length); result_prefill = Vec::with_capacity(tokens_length);
// Create Token objects // Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster // We do that here instead of in the Python code as Rust for loops are faster
for ((id, logprob), text) in tokens.ids.into_iter().zip( for ((id, logprob), text) in tokens
tokens.logprobs.into_iter() .ids
).zip(tokens.texts.into_iter()) { .into_iter()
result_prefill.push(PrefillToken{ .zip(tokens.logprobs.into_iter())
id, .zip(tokens.texts.into_iter())
text, {
logprob, result_prefill.push(PrefillToken { id, text, logprob });
});
} }
} }
// Push last token // Push last token