WIP
parent ebeea9daf8
commit ff5ca67f58

@@ -99,6 +99,44 @@ impl LlavaNext {
     }
 }
 
+pub trait VLMConfig {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String;
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Idefics;
+
+impl VLMConfig for Idefics {
+    fn tokenizer_input(&self, _height: usize, _width: usize) -> String {
+        "<image>".to_string()
+    }
+}
+
+impl VLMConfig for Idefics2 {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        let mut tokens = String::new();
+        tokens.push_str("<fake_token_around_image>");
+        tokens.push_str(&"<image>".repeat(slots));
+        tokens.push_str("<fake_token_around_image>");
+        tokens
+    }
+}
+
+impl VLMConfig for Paligemma {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        "<image>".repeat(slots)
+    }
+}
+
+impl VLMConfig for LlavaNext {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        "<image>".repeat(slots)
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {

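Note: the new VLMConfig trait gives every vision model a single hook for rendering its per-image token string. A minimal, self-contained sketch of the pattern (a stand-in struct, with get_number_of_features stubbed to a fixed slot count; the real impls above derive it from the image dimensions):

// Sketch only, not the crate's real types.
pub trait VLMConfig {
    fn tokenizer_input(&self, height: usize, width: usize) -> String;
}

struct Paligemma;

impl Paligemma {
    fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
        256 // stub: the real value depends on the image dimensions
    }
}

impl VLMConfig for Paligemma {
    fn tokenizer_input(&self, height: usize, width: usize) -> String {
        // One "<image>" placeholder per image-feature slot.
        "<image>".repeat(self.get_number_of_features(height, width))
    }
}

fn main() {
    let cfg = Paligemma;
    let query = cfg.tokenizer_input(224, 224);
    assert_eq!(query.matches("<image>").count(), 256);
}
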
@@ -141,7 +179,7 @@ pub enum Config {
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
     Mistral,
-    Idefics,
+    Idefics(Idefics),
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,

@@ -168,6 +206,18 @@ pub enum Config {
     T5,
 }
 
+impl Config {
+    pub fn vision_config(&self) -> Option<&dyn VLMConfig> {
+        match self {
+            Config::Idefics(config) => Some(config),
+            Config::Idefics2(config) => Some(config),
+            Config::LlavaNext(config) => Some(config),
+            Config::Paligemma(config) => Some(config),
+            _ => None,
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct TextConfig {}

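Note: vision_config() erases the concrete model behind &dyn VLMConfig, so callers branch once on Some/None instead of once per model. A hedged sketch of that dispatch with simplified stand-ins (not the crate's real enum or configs):

// Sketch only: a two-variant Config to show the trait-object dispatch.
trait VLMConfig {
    fn tokenizer_input(&self, height: usize, width: usize) -> String;
}

struct Idefics;

impl VLMConfig for Idefics {
    fn tokenizer_input(&self, _h: usize, _w: usize) -> String {
        "<image>".to_string()
    }
}

enum Config {
    Idefics(Idefics),
    Mistral, // text-only model: no vision config
}

impl Config {
    fn vision_config(&self) -> Option<&dyn VLMConfig> {
        match self {
            Config::Idefics(c) => Some(c),
            _ => None,
        }
    }
}

fn main() {
    let cfg = Config::Idefics(Idefics);
    // One model-agnostic call site instead of a per-model match arm.
    if let Some(vc) = cfg.vision_config() {
        println!("{}", vc.tokenizer_input(384, 384));
    }
    assert!(Config::Mistral.vision_config().is_none());
}
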
@@ -522,8 +522,8 @@ fn prepare_input(
     config: &Option<Config>,
 ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
-    let (tokenizer_query, input_chunks) = match config {
-        Some(Config::LlavaNext(config)) => {
+    let (tokenizer_query, input_chunks) = match config.as_ref().and_then(|c| c.vision_config()) {
+        Some(config) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;

@@ -535,9 +535,8 @@ fn prepare_input(
                 tokenizer_query.push_str(&inputs[start..chunk_start]);
             }
             let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-            let slots = config.get_number_of_features(height, width);
             input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-            tokenizer_query.push_str(&"<image>".repeat(slots));
+            tokenizer_query.push_str(&config.tokenizer_input(height, width));
             start = chunk_end;
         }
         if start != inputs.len() {

@@ -546,80 +545,7 @@ fn prepare_input(
             }
             (tokenizer_query, input_chunks)
         }
-        Some(Config::Paligemma(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics2(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                tokenizer_query.push_str("<fake_token_around_image>");
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                tokenizer_query.push_str("<fake_token_around_image>");
-
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, _height, _width) =
-                    fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = 1;
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
+        None => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
     };
 
     // Get the number of tokens in the input

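Note: with tokenizer_input on the trait, the four near-identical match arms above collapse into the single Some(config) loop. A simplified, runnable sketch of that loop (assumes the regex and once_cell crates; fetch_image is replaced by fixed dimensions and the chunk bookkeeping is omitted):

// Sketch only: rewrite markdown image markers into model-specific
// placeholder tokens, keeping the surrounding text intact.
use once_cell::sync::Lazy;
use regex::Regex;

trait VLMConfig {
    fn tokenizer_input(&self, height: usize, width: usize) -> String;
}

struct Idefics;

impl VLMConfig for Idefics {
    fn tokenizer_input(&self, _h: usize, _w: usize) -> String {
        "<image>".to_string()
    }
}

// Same marker pattern as the diff: markdown images like ![](http://...)
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());

fn prepare_query(inputs: &str, config: &dyn VLMConfig) -> String {
    let mut tokenizer_query = String::with_capacity(inputs.len());
    let mut start = 0;
    for m in RE.find_iter(inputs) {
        // Copy the text before the image marker verbatim.
        tokenizer_query.push_str(&inputs[start..m.start()]);
        // Stand-in for fetch_image: dimensions would come from the image.
        let (height, width) = (224, 224);
        tokenizer_query.push_str(&config.tokenizer_input(height, width));
        start = m.end();
    }
    // Copy any trailing text after the last marker.
    tokenizer_query.push_str(&inputs[start..]);
    tokenizer_query
}

fn main() {
    let out = prepare_query("describe ![](http://x/cat.png) please", &Idefics);
    assert_eq!(out, "describe <image> please");
}
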