This commit is contained in:
Daniël de Kok 2024-05-31 16:14:27 +00:00
parent ebeea9daf8
commit ff5ca67f58
2 changed files with 55 additions and 79 deletions

View File

@ -99,6 +99,44 @@ impl LlavaNext {
} }
} }
pub trait VLMConfig {
fn tokenizer_input(&self, height: usize, width: usize) -> String;
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Idefics;
impl VLMConfig for Idefics {
fn tokenizer_input(&self, _height: usize, _width: usize) -> String {
"<image>".to_string()
}
}
impl VLMConfig for Idefics2 {
fn tokenizer_input(&self, height: usize, width: usize) -> String {
let slots = self.get_number_of_features(height, width);
let mut tokens = String::new();
tokens.push_str("<fake_token_around_image>");
tokens.push_str(&"<image>".repeat(slots));
tokens.push_str("<fake_token_around_image>");
tokens
}
}
impl VLMConfig for Paligemma {
fn tokenizer_input(&self, height: usize, width: usize) -> String {
let slots = self.get_number_of_features(height, width);
"<image>".repeat(slots)
}
}
impl VLMConfig for LlavaNext {
fn tokenizer_input(&self, height: usize, width: usize) -> String {
let slots = self.get_number_of_features(height, width);
"<image>".repeat(slots)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct ClipVisionModel { pub struct ClipVisionModel {
@ -141,7 +179,7 @@ pub enum Config {
LlavaNext(LlavaNext), LlavaNext(LlavaNext),
ClipVisionModel(ClipVisionModel), ClipVisionModel(ClipVisionModel),
Mistral, Mistral,
Idefics, Idefics(Idefics),
Idefics2(Idefics2), Idefics2(Idefics2),
Ssm, Ssm,
GptBigcode, GptBigcode,
@ -168,6 +206,18 @@ pub enum Config {
T5, T5,
} }
impl Config {
pub fn vision_config(&self) -> Option<&dyn VLMConfig> {
match self {
Config::Idefics(config) => Some(config),
Config::Idefics2(config) => Some(config),
Config::LlavaNext(config) => Some(config),
Config::Paligemma(config) => Some(config),
_ => None,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct TextConfig {} pub struct TextConfig {}

View File

@ -522,8 +522,8 @@ fn prepare_input(
config: &Option<Config>, config: &Option<Config>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> { ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config.as_ref().and_then(|c| c.vision_config()) {
Some(Config::LlavaNext(config)) => { Some(config) => {
let mut input_chunks = Vec::new(); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;
@ -535,9 +535,8 @@ fn prepare_input(
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots)); tokenizer_query.push_str(&config.tokenizer_input(height, width));
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() { if start != inputs.len() {
@ -546,80 +545,7 @@ fn prepare_input(
} }
(tokenizer_query, input_chunks) (tokenizer_query, input_chunks)
} }
Some(Config::Paligemma(config)) => { None => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots));
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
Some(Config::Idefics2(config)) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
Some(Config::Idefics) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, _height, _width) =
fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
(tokenizer_query, input_chunks)
}
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
}; };
// Get the number of tokens in the input // Get the number of tokens in the input