From ff5ca67f58b2e59c887df5eba1c5a55b2ade0257 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Fri, 31 May 2024 16:14:27 +0000
Subject: [PATCH] WIP

---
 router/src/config.rs     | 52 ++++++++++++++++++++++++-
 router/src/validation.rs | 82 ++-------------------------------------
 2 files changed, 55 insertions(+), 79 deletions(-)

diff --git a/router/src/config.rs b/router/src/config.rs
index 29fefd5b..7f5b79a5 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -99,6 +99,44 @@ impl LlavaNext {
     }
 }
 
+pub trait VLMConfig {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String;
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Idefics;
+
+impl VLMConfig for Idefics {
+    fn tokenizer_input(&self, _height: usize, _width: usize) -> String {
+        "<image>".to_string()
+    }
+}
+
+impl VLMConfig for Idefics2 {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        let mut tokens = String::new();
+        tokens.push_str("<fake_token_around_image>");
+        tokens.push_str(&"<image>".repeat(slots));
+        tokens.push_str("<fake_token_around_image>");
+        tokens
+    }
+}
+
+impl VLMConfig for Paligemma {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        "<image>".repeat(slots)
+    }
+}
+
+impl VLMConfig for LlavaNext {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        "<image>".repeat(slots)
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
@@ -141,7 +179,7 @@ pub enum Config {
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
     Mistral,
-    Idefics,
+    Idefics(Idefics),
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,
@@ -168,6 +206,18 @@ pub enum Config {
     T5,
 }
 
+impl Config {
+    pub fn vision_config(&self) -> Option<&dyn VLMConfig> {
+        match self {
+            Config::Idefics(config) => Some(config),
+            Config::Idefics2(config) => Some(config),
+            Config::LlavaNext(config) => Some(config),
+            Config::Paligemma(config) => Some(config),
+            _ => None,
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct TextConfig {}
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 863bb99b..a48ef0b1 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -522,104 +522,30 @@ fn prepare_input(
     config: &Option<Config>,
 ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
-    let (tokenizer_query, input_chunks) = match config {
-        Some(Config::LlavaNext(config)) => {
+    let (tokenizer_query, input_chunks) = match config.as_ref().and_then(|c| c.vision_config()) {
+        Some(config) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
             for chunk in RE.find_iter(&inputs) {
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
                 if chunk_start != start {
                     input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
                 input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str(&config.tokenizer_input(height, width));
                 start = chunk_end;
             }
             if start != inputs.len() {
                 input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                 tokenizer_query.push_str(&inputs[start..]);
             }
             (tokenizer_query, input_chunks)
         }
-        Some(Config::Paligemma(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics2(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                tokenizer_query.push_str("<fake_token_around_image>");
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                tokenizer_query.push_str("<fake_token_around_image>");
-
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, _height, _width) =
-                    fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = 1;
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
+        None => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
     };
 
     // Get the number of tokens in the input
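
Note (not part of the patch): the sketch below condenses the dispatch pattern the
diff introduces, for readers who want to see it in isolation. It is a minimal,
self-contained approximation: the struct field, the slot count, and the text-only
Mistral stand-in are invented for illustration, and the real configs compute the
slot count from the image size via get_number_of_features, which is elided here.

    /// One method per vision-language model: expand a fetched image into the
    /// placeholder string that is fed to the tokenizer in its place.
    trait VLMConfig {
        fn tokenizer_input(&self, height: usize, width: usize) -> String;
    }

    struct Idefics;
    struct Paligemma {
        slots_per_image: usize, // stand-in for get_number_of_features()
    }

    impl VLMConfig for Idefics {
        fn tokenizer_input(&self, _height: usize, _width: usize) -> String {
            "<image>".to_string()
        }
    }

    impl VLMConfig for Paligemma {
        fn tokenizer_input(&self, _height: usize, _width: usize) -> String {
            "<image>".repeat(self.slots_per_image)
        }
    }

    enum Config {
        Idefics(Idefics),
        Paligemma(Paligemma),
        Mistral, // text-only model: no vision config
    }

    impl Config {
        /// VLM variants coerce to a shared trait object; text-only models
        /// fall through to None, exactly as in the patched vision_config().
        fn vision_config(&self) -> Option<&dyn VLMConfig> {
            match self {
                Config::Idefics(c) => Some(c),
                Config::Paligemma(c) => Some(c),
                _ => None,
            }
        }
    }

    fn main() {
        let configs = [
            Config::Idefics(Idefics),
            Config::Paligemma(Paligemma { slots_per_image: 3 }),
            Config::Mistral,
        ];
        for config in &configs {
            // The caller no longer matches on every VLM variant: a single
            // Some arm handles all models that implement VLMConfig.
            match config.vision_config() {
                Some(vlm) => println!("{}", vlm.tokenizer_input(224, 224)),
                None => println!("(no image tokens)"),
            }
        }
    }

This is why prepare_input shrinks from four near-identical match arms to one:
supporting a new vision model now takes a VLMConfig impl plus one arm in
vision_config, instead of another copy of the chunking loop.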