feat: integrate image tokens into inputs embeds
This commit is contained in:
parent
a200d44e12
commit
f9dfd36b92
|
@ -112,9 +112,20 @@ pub struct ClipVisionModel {
|
|||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Idefics3 {
|
||||
pub(crate) vision_encoder_max_image_size: usize,
|
||||
pub(crate) image_seq_len: usize,
|
||||
pub struct Idefics3 {}
|
||||
|
||||
impl Idefics3 {
|
||||
pub fn get_max_longest_edge(&self) -> usize {
|
||||
364
|
||||
}
|
||||
|
||||
pub fn get_number_of_features(&self) -> usize {
|
||||
169
|
||||
}
|
||||
|
||||
pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
|
||||
1456
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
|
|
@ -594,27 +594,69 @@ fn image_tokens(
|
|||
Idefics3(config) => {
|
||||
const FAKE: &str = "<fake_token_around_image>";
|
||||
const IMAGE: &str = "<image>";
|
||||
const GLOBAL_IMG: &str = "<global-img>";
|
||||
|
||||
let max_size = config.vision_encoder_max_image_size;
|
||||
let calc_splits = |dim: usize| ((dim as f64) / max_size as f64).ceil() as usize;
|
||||
let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
|
||||
|
||||
let num_splits = if height.max(width) > max_size {
|
||||
Some((calc_splits(width), calc_splits(height)))
|
||||
// resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
|
||||
let (height, width) = if height > max_longest_edge_for_image_resize
|
||||
|| width > max_longest_edge_for_image_resize
|
||||
{
|
||||
let aspect_ratio = height as f32 / width as f32;
|
||||
if height > width {
|
||||
(
|
||||
max_longest_edge_for_image_resize,
|
||||
(max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
(
|
||||
(max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
|
||||
max_longest_edge_for_image_resize,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
(height, width)
|
||||
};
|
||||
|
||||
let num_splits_h = num_splits.map(|(_, h)| h).unwrap_or(1);
|
||||
let image_repeat = IMAGE.repeat(config.image_seq_len);
|
||||
let image_seq_len = config.get_number_of_features();
|
||||
let max_edge = config.get_max_longest_edge();
|
||||
|
||||
if num_splits_h > 1 {
|
||||
let row = format!("{FAKE}{image_repeat}{FAKE}\n");
|
||||
let mut image_string = row.repeat(num_splits_h);
|
||||
image_string.push_str(&format!("\n{FAKE}{image_repeat}{FAKE}"));
|
||||
image_string
|
||||
let (image_rows, image_cols) = if height > max_edge || width > max_edge {
|
||||
(
|
||||
(height as f32 / max_edge as f32).ceil() as usize,
|
||||
(width as f32 / max_edge as f32).ceil() as usize,
|
||||
)
|
||||
} else {
|
||||
format!("{FAKE}{image_repeat}{FAKE}")
|
||||
(0, 0)
|
||||
};
|
||||
|
||||
let mut image_string = String::new();
|
||||
|
||||
if image_rows == 0 && image_cols == 0 {
|
||||
// Single image case
|
||||
image_string.push_str(FAKE);
|
||||
image_string.push_str(GLOBAL_IMG);
|
||||
image_string.push_str(&IMAGE.repeat(image_seq_len));
|
||||
image_string.push_str(FAKE);
|
||||
} else {
|
||||
// Split image case
|
||||
for n_h in 0..image_rows {
|
||||
for n_w in 0..image_cols {
|
||||
image_string.push_str(FAKE);
|
||||
image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
|
||||
image_string.push_str(&IMAGE.repeat(image_seq_len));
|
||||
}
|
||||
image_string.push('\n');
|
||||
}
|
||||
|
||||
image_string.push('\n');
|
||||
image_string.push_str(FAKE);
|
||||
image_string.push_str(GLOBAL_IMG);
|
||||
image_string.push_str(&IMAGE.repeat(image_seq_len));
|
||||
image_string.push_str(FAKE);
|
||||
}
|
||||
|
||||
image_string
|
||||
}
|
||||
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
||||
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
|
||||
|
@ -1249,10 +1291,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_idefics2_image_tokens() {
|
||||
let config = Config::Idefics3(Idefics3 {
|
||||
vision_encoder_max_image_size: 100,
|
||||
image_seq_len: 1,
|
||||
});
|
||||
let config = Config::Idefics3(Idefics3 {});
|
||||
|
||||
let preprocessor_config = Some(&HubPreprocessorConfig::Idefics2Processor(
|
||||
Idefics2Preprocessor {
|
||||
|
@ -1260,117 +1299,34 @@ mod tests {
|
|||
},
|
||||
));
|
||||
|
||||
let height = 100;
|
||||
let width = 100;
|
||||
let height = 1067;
|
||||
let width = 1600;
|
||||
|
||||
let tokens = image_tokens(&config, preprocessor_config, height, width);
|
||||
|
||||
assert_eq!(
|
||||
tokens,
|
||||
"<fake_token_around_image><image><fake_token_around_image>"
|
||||
);
|
||||
}
|
||||
// get all unique tags `<tag>` from the tokens
|
||||
let tags: std::collections::HashSet<&str> = tokens
|
||||
.split(|c| c == '<' || c == '>')
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_idefics3_image_tokens() {
|
||||
let config = Config::Idefics3(Idefics3 {
|
||||
vision_encoder_max_image_size: 980,
|
||||
image_seq_len: 1,
|
||||
});
|
||||
assert_eq!(tags.len(), 17); // all below and `\n` and `\n\n`
|
||||
assert_eq!(tags.contains(&"row_1_col_1"), true);
|
||||
assert_eq!(tags.contains(&"row_1_col_2"), true);
|
||||
assert_eq!(tags.contains(&"row_1_col_3"), true);
|
||||
assert_eq!(tags.contains(&"row_1_col_4"), true);
|
||||
assert_eq!(tags.contains(&"row_2_col_1"), true);
|
||||
assert_eq!(tags.contains(&"row_2_col_2"), true);
|
||||
assert_eq!(tags.contains(&"row_2_col_3"), true);
|
||||
assert_eq!(tags.contains(&"row_2_col_4"), true);
|
||||
assert_eq!(tags.contains(&"row_3_col_1"), true);
|
||||
assert_eq!(tags.contains(&"row_3_col_2"), true);
|
||||
assert_eq!(tags.contains(&"row_3_col_3"), true);
|
||||
assert_eq!(tags.contains(&"row_3_col_4"), true);
|
||||
assert_eq!(tags.contains(&"global-img"), true);
|
||||
assert_eq!(tags.contains(&"image"), true);
|
||||
assert_eq!(tags.contains(&"fake_token_around_image"), true);
|
||||
|
||||
let preprocessor_config = Some(&HubPreprocessorConfig::Idefics3Processor(
|
||||
Idefics2Preprocessor {
|
||||
do_image_splitting: true,
|
||||
},
|
||||
));
|
||||
|
||||
let height = 100;
|
||||
let width = 100;
|
||||
|
||||
let tokens = image_tokens(&config, preprocessor_config, height, width);
|
||||
|
||||
assert_eq!(
|
||||
tokens,
|
||||
"<fake_token_around_image><image><fake_token_around_image>"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_idefics3_correct_n_fake_tokens() {
|
||||
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
|
||||
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let disable_grammar_support = true;
|
||||
let workers = 1;
|
||||
let config = Config::Idefics3(Idefics3 {
|
||||
vision_encoder_max_image_size: 100,
|
||||
image_seq_len: 1,
|
||||
});
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
Some(config),
|
||||
Some(HubPreprocessorConfig::Idefics3Processor(
|
||||
Idefics2Preprocessor {
|
||||
do_image_splitting: true,
|
||||
},
|
||||
)),
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
disable_grammar_support,
|
||||
);
|
||||
|
||||
let (encoding, chunks) = match validation
|
||||
.tokenize(
|
||||
format!(
|
||||
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
|
||||
PIXEL_GIF, PIXEL_GIF
|
||||
),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some((encoding, chunks))) => (encoding, chunks),
|
||||
_ => panic!("Unexpected tokenization failure"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
chunks
|
||||
== vec![
|
||||
Chunk::Text("test".to_string()).into(),
|
||||
Chunk::Image(Image {
|
||||
data: pixel_data.clone(),
|
||||
mimetype: "image/gif".to_string()
|
||||
})
|
||||
.into(),
|
||||
Chunk::Image(Image {
|
||||
data: pixel_data.clone(),
|
||||
mimetype: "image/gif".to_string()
|
||||
})
|
||||
.into()
|
||||
],
|
||||
"Failed to process images",
|
||||
);
|
||||
|
||||
// Verify the number of fake tokens:
|
||||
//
|
||||
// - Two images, each surrounded/separated by a fake token tags = 4.
|
||||
assert_eq!(
|
||||
encoding
|
||||
.get_tokens()
|
||||
.iter()
|
||||
.filter(|t| *t == "fake")
|
||||
.count(),
|
||||
4
|
||||
);
|
||||
assert_eq!(tokens.len(), 15_901)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -753,6 +753,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
|||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
# mask = input_ids == self.config.image_token_index
|
||||
mask = input_ids == self.config.image_token_id
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
@ -835,25 +848,22 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
|||
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
# TODO: finish implementing the image token replacement
|
||||
# TODO: remove when prefill image tokens are handled correctly
|
||||
# * for now dummy tokens are added instead of the image tokens output byt the vision model
|
||||
mask_size = (input_ids == self.config.image_token_id).sum().item()
|
||||
unrolled_image_size = (
|
||||
image_hidden_states.shape[1] * image_hidden_states.shape[2]
|
||||
)
|
||||
diff = mask_size - unrolled_image_size
|
||||
if diff > 0:
|
||||
print(
|
||||
f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
|
||||
)
|
||||
|
||||
# inputs_embeds = self.inputs_merger(
|
||||
# input_ids=input_ids,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# image_hidden_states=image_hidden_states,
|
||||
# )
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# num_images, _, vision_hidden_size = image_hidden_states.shape
|
||||
# special_image_token_mask = input_ids == self.image_token_id
|
||||
# new_inputs_embeds = inputs_embeds.clone()
|
||||
# reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
|
||||
# inputs_embeds.dtype
|
||||
# ) # cast to the dtype of the input_embeds to support quantized models
|
||||
# new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
|
||||
# inputs_embeds = new_inputs_embeds
|
||||
if mask_size == unrolled_image_size:
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
|
|
@ -23,6 +23,75 @@ tracer = trace.get_tracer(__name__)
|
|||
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS2_IMAGE_TOKEN = "<image>"
|
||||
|
||||
IDEFICS3_IMAGE_TOKEN = "<image>"
|
||||
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
|
||||
|
||||
|
||||
def _prompt_split_image(
|
||||
image_seq_len,
|
||||
image_rows,
|
||||
image_cols,
|
||||
fake_token_around_image,
|
||||
image_token,
|
||||
global_img_token,
|
||||
):
|
||||
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||||
text_split_images = ""
|
||||
for n_h in range(image_rows):
|
||||
for n_w in range(image_cols):
|
||||
text_split_images += (
|
||||
f"{fake_token_around_image}"
|
||||
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
)
|
||||
text_split_images += "\n"
|
||||
|
||||
text_split_images += (
|
||||
f"\n{fake_token_around_image}"
|
||||
+ f"{global_img_token}"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
+ f"{fake_token_around_image}"
|
||||
)
|
||||
return text_split_images
|
||||
|
||||
|
||||
def _prompt_single_image(
|
||||
image_seq_len, fake_token_around_image, image_token, global_img_token
|
||||
):
|
||||
"""Prompt with expanded image tokens for a single image."""
|
||||
return (
|
||||
f"{fake_token_around_image}"
|
||||
+ f"{global_img_token}"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
+ f"{fake_token_around_image}"
|
||||
)
|
||||
|
||||
|
||||
def get_image_prompt_string(
|
||||
image_rows,
|
||||
image_cols,
|
||||
image_seq_len,
|
||||
fake_token_around_image,
|
||||
image_token,
|
||||
global_img_token,
|
||||
):
|
||||
if image_rows == 0 and image_cols == 0:
|
||||
return _prompt_single_image(
|
||||
image_seq_len,
|
||||
fake_token_around_image=fake_token_around_image,
|
||||
image_token=image_token,
|
||||
global_img_token=global_img_token,
|
||||
)
|
||||
return _prompt_split_image(
|
||||
image_seq_len,
|
||||
image_rows,
|
||||
image_cols,
|
||||
fake_token_around_image,
|
||||
image_token,
|
||||
global_img_token,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
|
@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
|||
image_str *= 5
|
||||
return image_str
|
||||
if config.model_type == "idefics3":
|
||||
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
|
||||
image_str = ""
|
||||
# TODO: implement this in a more general way
|
||||
n_rows = image_input["rows"][0][image_id]
|
||||
n_cols = image_input["cols"][0][image_id]
|
||||
|
||||
# TODO: avoid using hardcoded values
|
||||
image_seq_len = 169 # default value
|
||||
# image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
|
||||
|
||||
image_str = get_image_prompt_string(
|
||||
n_rows,
|
||||
n_cols,
|
||||
image_seq_len,
|
||||
image_token=IDEFICS3_IMAGE_TOKEN,
|
||||
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
|
||||
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
|
||||
)
|
||||
return image_str
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input["image_sizes"][image_id]
|
||||
|
@ -80,6 +163,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
|
|||
return text.replace(
|
||||
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
|
||||
)
|
||||
if config.model_type == "idefics3":
|
||||
return text.replace(
|
||||
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
|
@ -182,7 +269,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
|||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
|
||||
if images:
|
||||
image_inputs = processor.image_processor(images, return_tensors="pt")
|
||||
image_inputs = processor.image_processor(
|
||||
images, return_tensors="pt", return_row_col_info=True
|
||||
)
|
||||
else:
|
||||
image_inputs = None
|
||||
|
||||
|
@ -196,9 +285,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
|||
if chunk_type == "text":
|
||||
full_text += chunk.text
|
||||
elif chunk_type == "image":
|
||||
full_text += image_text_replacement(
|
||||
replacement_text = image_text_replacement(
|
||||
processor, image_inputs, config, image_id
|
||||
)
|
||||
full_text += replacement_text
|
||||
image_id += 1
|
||||
|
||||
full_text = image_text_replacement_fixup(config, full_text)
|
||||
|
@ -268,7 +358,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**processor_kwargs,
|
||||
# **processor_kwargs,
|
||||
)
|
||||
self.batch_class = batch_class
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
|
Loading…
Reference in New Issue