WIP
This commit is contained in:
parent ebeea9daf8
commit ff5ca67f58
@@ -99,6 +99,44 @@ impl LlavaNext {
     }
 }
 
+pub trait VLMConfig {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String;
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Idefics;
+
+impl VLMConfig for Idefics {
+    fn tokenizer_input(&self, _height: usize, _width: usize) -> String {
+        "<image>".to_string()
+    }
+}
+
+impl VLMConfig for Idefics2 {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        let mut tokens = String::new();
+        tokens.push_str("<fake_token_around_image>");
+        tokens.push_str(&"<image>".repeat(slots));
+        tokens.push_str("<fake_token_around_image>");
+        tokens
+    }
+}
+
+impl VLMConfig for Paligemma {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        "<image>".repeat(slots)
+    }
+}
+
+impl VLMConfig for LlavaNext {
+    fn tokenizer_input(&self, height: usize, width: usize) -> String {
+        let slots = self.get_number_of_features(height, width);
+        "<image>".repeat(slots)
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
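Note (not part of the commit): a minimal sketch of how the new trait can be exercised through a trait object. Only the Idefics case is shown, because it ignores the image dimensions; the other implementations depend on get_number_of_features, whose inputs live outside this hunk.

// Illustration only: drive tokenizer_input through &dyn VLMConfig.
fn image_placeholder(vlm: &dyn VLMConfig, height: usize, width: usize) -> String {
    vlm.tokenizer_input(height, width)
}

fn main() {
    // Idefics maps an image to a single "<image>" marker, regardless of resolution.
    assert_eq!(image_placeholder(&Idefics, 384, 384), "<image>");
    assert_eq!(image_placeholder(&Idefics, 1024, 768), "<image>");
}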
@@ -141,7 +179,7 @@ pub enum Config {
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
     Mistral,
-    Idefics,
+    Idefics(Idefics),
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,
@@ -168,6 +206,18 @@ pub enum Config {
     T5,
 }
 
+impl Config {
+    pub fn vision_config(&self) -> Option<&dyn VLMConfig> {
+        match self {
+            Config::Idefics(config) => Some(config),
+            Config::Idefics2(config) => Some(config),
+            Config::LlavaNext(config) => Some(config),
+            Config::Paligemma(config) => Some(config),
+            _ => None,
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct TextConfig {}
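Note (not from the diff): a short sketch of how the new accessor is meant to be consumed. The caller no longer needs to know which vision model it is talking to, only whether there is one at all.

// Illustration only: resolve a Config to its image placeholder string, if any.
fn placeholder_for(config: &Config, height: usize, width: usize) -> Option<String> {
    config
        .vision_config()
        .map(|vlm| vlm.tokenizer_input(height, width))
}

// Text-only variants (Mistral, Ssm, GptBigcode, ...) return None and keep the
// plain-text path in prepare_input; the VLM variants dispatch through the trait.
// For example, placeholder_for(&Config::Idefics(Idefics), 384, 384) yields
// Some("<image>".to_string()).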
@@ -522,8 +522,8 @@ fn prepare_input(
     config: &Option<Config>,
 ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
-    let (tokenizer_query, input_chunks) = match config {
-        Some(Config::LlavaNext(config)) => {
+    let (tokenizer_query, input_chunks) = match config.as_ref().and_then(|c| c.vision_config()) {
+        Some(config) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
@@ -535,9 +535,8 @@ fn prepare_input(
                 tokenizer_query.push_str(&inputs[start..chunk_start]);
             }
             let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-            let slots = config.get_number_of_features(height, width);
             input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-            tokenizer_query.push_str(&"<image>".repeat(slots));
+            tokenizer_query.push_str(&config.tokenizer_input(height, width));
             start = chunk_end;
         }
         if start != inputs.len() {
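Note (not from the diff): the delegated config.tokenizer_input(height, width) call above replaces the per-model string building that the next hunk deletes. For Idefics2 the placeholder wraps the repeated "<image>" markers in "<fake_token_around_image>" guards; a stand-alone sketch with a made-up slot count (the real value comes from get_number_of_features(height, width)):

fn main() {
    let slots = 64; // hypothetical; depends on the image size and model config
    let mut tokens = String::new();
    tokens.push_str("<fake_token_around_image>");
    tokens.push_str(&"<image>".repeat(slots));
    tokens.push_str("<fake_token_around_image>");
    assert_eq!(tokens.matches("<image>").count(), slots);
    assert!(tokens.starts_with("<fake_token_around_image>"));
}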
@@ -546,80 +545,7 @@ fn prepare_input(
             }
             (tokenizer_query, input_chunks)
         }
-        Some(Config::Paligemma(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics2(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                tokenizer_query.push_str("<fake_token_around_image>");
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                tokenizer_query.push_str("<fake_token_around_image>");
-
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, _height, _width) =
-                    fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = 1;
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
+        None => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
     };
 
     // Get the number of tokens in the input
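Note (not from the diff): a stand-alone sketch of the chunking loop that prepare_input keeps after this refactor. It scans the prompt for markdown image syntax with the same regex and alternates text and image chunks. It assumes the regex and once_cell crates already used by the router; the URL below is a made-up placeholder.

use once_cell::sync::Lazy;
use regex::Regex;

// Same pattern as prepare_input: markdown image syntax `![](...)`.
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());

fn main() {
    let inputs = "What is shown here? ![](https://example.com/cat.png) Answer briefly.";
    let mut start = 0;
    for m in RE.find_iter(inputs) {
        if m.start() != start {
            println!("text chunk : {:?}", &inputs[start..m.start()]);
        }
        // In the router, this URL (or data URI) is passed to fetch_image and the
        // placeholder text comes from config.tokenizer_input(height, width).
        println!("image chunk: {:?}", &inputs[m.start()..m.end()]);
        start = m.end();
    }
    if start != inputs.len() {
        println!("text chunk : {:?}", &inputs[start..]);
    }
}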