diff --git a/router/src/validation.rs b/router/src/validation.rs index f82f9670..23f5e750 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -6,7 +6,6 @@ use crate::{ }; use base64::{engine::general_purpose::STANDARD, Engine}; use image::{ImageFormat, ImageReader}; -use itertools::Itertools; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; @@ -720,15 +719,20 @@ fn prepare_input( // Replace image tokens by hashes. The first token of an image // must be specific to the image for prefix caching. + // TODO: disable when not caching prefixes. let mut token_ids = encoding.get_ids().to_owned(); let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id); for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) { - let image_token = iter.next().ok_or(ValidationError::Tokenizer( - "Image token not found".to_string(), - ))?; - *image_token = *image_hash as u32; - // Skip the remaining tokens of the current image. - iter = iter.dropping(n_tokens - 1); + for i in 0..*n_tokens { + match iter.next() { + Some(image_token) => *image_token = *image_hash as u32 + i as u32, + None => { + return Err(ValidationError::Tokenizer( + "Image token not found".to_string(), + )) + } + } + } } let encoding = Encoding::new(