unique hash for each image token

This commit is contained in:
Daniël de Kok 2024-09-03 12:56:02 +00:00
parent 8c74ee4498
commit 69dd51069f
1 changed file with 11 additions and 7 deletions

View File

@@ -6,7 +6,6 @@ use crate::{
}; };
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader}; use image::{ImageFormat, ImageReader};
use itertools::Itertools;
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
@@ -720,15 +719,20 @@ fn prepare_input(
// Replace image tokens by hashes. The first token of an image // Replace image tokens by hashes. The first token of an image
// must be specific to the image for prefix caching. // must be specific to the image for prefix caching.
// TODO: disable when not caching prefixes.
let mut token_ids = encoding.get_ids().to_owned(); let mut token_ids = encoding.get_ids().to_owned();
let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id); let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) { for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
let image_token = iter.next().ok_or(ValidationError::Tokenizer( for i in 0..*n_tokens {
"Image token not found".to_string(), match iter.next() {
))?; Some(image_token) => *image_token = *image_hash as u32 + i as u32,
*image_token = *image_hash as u32; None => {
// Skip the remaining tokens of the current image. return Err(ValidationError::Tokenizer(
iter = iter.dropping(n_tokens - 1); "Image token not found".to_string(),
))
}
}
}
} }
let encoding = Encoding::new( let encoding = Encoding::new(