feat: integrate image tokens into inputs embeds

drbh 2024-08-20 16:33:39 +00:00
parent a200d44e12
commit f9dfd36b92
4 changed files with 217 additions and 150 deletions


@@ -112,9 +112,20 @@ pub struct ClipVisionModel {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics3 {
-    pub(crate) vision_encoder_max_image_size: usize,
-    pub(crate) image_seq_len: usize,
-}
+pub struct Idefics3 {}
+
+impl Idefics3 {
+    pub fn get_max_longest_edge(&self) -> usize {
+        364
+    }
+
+    pub fn get_number_of_features(&self) -> usize {
+        169
+    }
+
+    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
+        1456
+    }
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
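
A note on the hardcoded values above (not part of the commit): 364 is the vision encoder's longest edge, 1456 = 4 * 364 is the pre-split resize cap, and 169 follows from the formula quoted in a comment later in this diff, `image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))`. A minimal Python check, assuming a 14-pixel patch size and a pixel-shuffle scale factor of 2 (neither value appears in this commit):

    # patch_size and scale_factor are assumed values, not taken from this diff
    image_size = 364   # get_max_longest_edge()
    patch_size = 14    # assumed vision-encoder patch size
    scale_factor = 2   # assumed pixel-shuffle scale factor
    image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
    assert image_seq_len == 169  # matches get_number_of_features()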


@@ -594,27 +594,69 @@ fn image_tokens(
         Idefics3(config) => {
             const FAKE: &str = "<fake_token_around_image>";
             const IMAGE: &str = "<image>";
-            let max_size = config.vision_encoder_max_image_size;
-            let calc_splits = |dim: usize| ((dim as f64) / max_size as f64).ceil() as usize;
-            let num_splits = if height.max(width) > max_size {
-                Some((calc_splits(width), calc_splits(height)))
-            } else {
-                None
-            };
-            let num_splits_h = num_splits.map(|(_, h)| h).unwrap_or(1);
-            let image_repeat = IMAGE.repeat(config.image_seq_len);
-            if num_splits_h > 1 {
-                let row = format!("{FAKE}{image_repeat}{FAKE}\n");
-                let mut image_string = row.repeat(num_splits_h);
-                image_string.push_str(&format!("\n{FAKE}{image_repeat}{FAKE}"));
-                image_string
-            } else {
-                format!("{FAKE}{image_repeat}{FAKE}")
-            }
+            const GLOBAL_IMG: &str = "<global-img>";
+
+            let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
+
+            // resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
+            let (height, width) = if height > max_longest_edge_for_image_resize
+                || width > max_longest_edge_for_image_resize
+            {
+                let aspect_ratio = height as f32 / width as f32;
+                if height > width {
+                    (
+                        max_longest_edge_for_image_resize,
+                        (max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
+                    )
+                } else {
+                    (
+                        (max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
+                        max_longest_edge_for_image_resize,
+                    )
+                }
+            } else {
+                (height, width)
+            };
+
+            let image_seq_len = config.get_number_of_features();
+            let max_edge = config.get_max_longest_edge();
+
+            let (image_rows, image_cols) = if height > max_edge || width > max_edge {
+                (
+                    (height as f32 / max_edge as f32).ceil() as usize,
+                    (width as f32 / max_edge as f32).ceil() as usize,
+                )
+            } else {
+                (0, 0)
+            };
+
+            let mut image_string = String::new();
+
+            if image_rows == 0 && image_cols == 0 {
+                // Single image case
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            } else {
+                // Split image case
+                for n_h in 0..image_rows {
+                    for n_w in 0..image_cols {
+                        image_string.push_str(FAKE);
+                        image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
+                        image_string.push_str(&IMAGE.repeat(image_seq_len));
+                    }
+                    image_string.push('\n');
+                }
+                image_string.push('\n');
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            }
+            image_string
         }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
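
A small sketch (not part of the commit) tracing the resize-and-split arithmetic above for the 1067x1600 image used in the updated test below; values mirror the Idefics3 config impl:

    import math

    height, width = 1067, 1600
    resize_edge, max_edge = 1456, 364  # get_max_longest_edge_for_image_resize(), get_max_longest_edge()

    if max(height, width) > resize_edge:  # resize, keeping the aspect ratio
        aspect_ratio = height / width
        if height > width:
            height, width = resize_edge, int(resize_edge / aspect_ratio)
        else:
            height, width = int(resize_edge * aspect_ratio), resize_edge

    rows = math.ceil(height / max_edge)  # ceil(970 / 364) = 3
    cols = math.ceil(width / max_edge)   # ceil(1456 / 364) = 4
    assert (height, width, rows, cols) == (970, 1456, 3, 4)
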
@@ -1249,10 +1291,7 @@ mod tests {
     #[tokio::test]
     async fn test_idefics2_image_tokens() {
-        let config = Config::Idefics3(Idefics3 {
-            vision_encoder_max_image_size: 100,
-            image_seq_len: 1,
-        });
+        let config = Config::Idefics3(Idefics3 {});

         let preprocessor_config = Some(&HubPreprocessorConfig::Idefics2Processor(
             Idefics2Preprocessor {
@@ -1260,117 +1299,34 @@ mod tests {
             },
         ));

-        let height = 100;
-        let width = 100;
+        let height = 1067;
+        let width = 1600;

         let tokens = image_tokens(&config, preprocessor_config, height, width);

-        assert_eq!(
-            tokens,
-            "<fake_token_around_image><image><fake_token_around_image>"
-        );
-    }
-
-    #[tokio::test]
-    async fn test_idefics3_image_tokens() {
-        let config = Config::Idefics3(Idefics3 {
-            vision_encoder_max_image_size: 980,
-            image_seq_len: 1,
-        });
-
-        let preprocessor_config = Some(&HubPreprocessorConfig::Idefics3Processor(
-            Idefics2Preprocessor {
-                do_image_splitting: true,
-            },
-        ));
-
-        let height = 100;
-        let width = 100;
-
-        let tokens = image_tokens(&config, preprocessor_config, height, width);
-
-        assert_eq!(
-            tokens,
-            "<fake_token_around_image><image><fake_token_around_image>"
-        );
-    }
-
-    #[tokio::test]
-    async fn test_idefics3_correct_n_fake_tokens() {
-        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
-
-        let tokenizer = Some(get_tokenizer().await);
-
-        let max_best_of = 2;
-        let max_stop_sequence = 3;
-        let max_top_n_tokens = 4;
-        let max_input_length = 5;
-        let max_total_tokens = 6;
-        let disable_grammar_support = true;
-        let workers = 1;
-        let config = Config::Idefics3(Idefics3 {
-            vision_encoder_max_image_size: 100,
-            image_seq_len: 1,
-        });
-        let validation = Validation::new(
-            workers,
-            tokenizer,
-            Some(config),
-            Some(HubPreprocessorConfig::Idefics3Processor(
-                Idefics2Preprocessor {
-                    do_image_splitting: true,
-                },
-            )),
-            max_best_of,
-            max_stop_sequence,
-            max_top_n_tokens,
-            max_input_length,
-            max_total_tokens,
-            disable_grammar_support,
-        );
-
-        let (encoding, chunks) = match validation
-            .tokenize(
-                format!(
-                    "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
-                    PIXEL_GIF, PIXEL_GIF
-                ),
-                None,
-            )
-            .await
-        {
-            Ok(Some((encoding, chunks))) => (encoding, chunks),
-            _ => panic!("Unexpected tokenization failure"),
-        };
-
-        assert!(
-            chunks
-                == vec![
-                    Chunk::Text("test".to_string()).into(),
-                    Chunk::Image(Image {
-                        data: pixel_data.clone(),
-                        mimetype: "image/gif".to_string()
-                    })
-                    .into(),
-                    Chunk::Image(Image {
-                        data: pixel_data.clone(),
-                        mimetype: "image/gif".to_string()
-                    })
-                    .into()
-                ],
-            "Failed to process images",
-        );
-
-        // Verify the number of fake tokens:
-        //
-        // - Two images, each surrounded/separated by a fake token tags = 4.
-        assert_eq!(
-            encoding
-                .get_tokens()
-                .iter()
-                .filter(|t| *t == "fake")
-                .count(),
-            4
-        );
+        // get all unique tags `<tag>` from the tokens
+        let tags: std::collections::HashSet<&str> = tokens
+            .split(|c| c == '<' || c == '>')
+            .filter(|s| !s.is_empty())
+            .collect();
+
+        assert_eq!(tags.len(), 17); // all below and `\n` and `\n\n`
+        assert!(tags.contains("row_1_col_1"));
+        assert!(tags.contains("row_1_col_2"));
+        assert!(tags.contains("row_1_col_3"));
+        assert!(tags.contains("row_1_col_4"));
+        assert!(tags.contains("row_2_col_1"));
+        assert!(tags.contains("row_2_col_2"));
+        assert!(tags.contains("row_2_col_3"));
+        assert!(tags.contains("row_2_col_4"));
+        assert!(tags.contains("row_3_col_1"));
+        assert!(tags.contains("row_3_col_2"));
+        assert!(tags.contains("row_3_col_3"));
+        assert!(tags.contains("row_3_col_4"));
+        assert!(tags.contains("global-img"));
+        assert!(tags.contains("image"));
+        assert!(tags.contains("fake_token_around_image"));
+
+        assert_eq!(tokens.len(), 15_901);
     }
 }
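
As a cross-check (not part of the commit), the expected length of 15_901 can be reproduced by hand: the resized 970x1456 image splits into a 3x4 grid of patches, each expanded to 169 `<image>` tokens, plus the trailing `<global-img>` view:

    FAKE, IMG, GLOBAL = "<fake_token_around_image>", "<image>", "<global-img>"
    seq_len, rows, cols = 169, 3, 4

    patch = len(FAKE) + len("<row_1_col_1>") + seq_len * len(IMG)  # 25 + 13 + 1183
    grid = rows * cols * patch + rows * len("\n")                  # 12 * 1221 + 3
    global_part = len("\n") + 2 * len(FAKE) + len(GLOBAL) + seq_len * len(IMG)
    assert grid + global_part == 15_901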


@@ -753,6 +753,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )

+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In-place merge of vision embeddings into inputs_embeds."""
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
+        # Let's pray we have enabled enough slots !
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -835,25 +848,22 @@ class Idefics3ForConditionalGeneration(nn.Module):
                     all_states.append(image_hidden_states)
                 image_hidden_states = torch.stack(all_states, dim=0)

-            # When we generate, we don't want to replace the potential image_token_id that we generated by images
-            # that simply don't exist
-            # TODO: finish implementing the image token replacement
-            # inputs_embeds = self.inputs_merger(
-            #     input_ids=input_ids,
-            #     inputs_embeds=inputs_embeds,
-            #     image_hidden_states=image_hidden_states,
-            # )
-            # import ipdb; ipdb.set_trace()
-            # num_images, _, vision_hidden_size = image_hidden_states.shape
-            # special_image_token_mask = input_ids == self.image_token_id
-            # new_inputs_embeds = inputs_embeds.clone()
-            # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
-            #     inputs_embeds.dtype
-            # )  # cast to the dtype of the input_embeds to support quantized models
-            # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
-            # inputs_embeds = new_inputs_embeds
+            # TODO: remove when prefill image tokens are handled correctly
+            # * for now dummy tokens are added instead of the image tokens output by the vision model
+            mask_size = (input_ids == self.config.image_token_id).sum().item()
+            unrolled_image_size = (
+                image_hidden_states.shape[1] * image_hidden_states.shape[2]
+            )
+            diff = mask_size - unrolled_image_size
+            if diff > 0:
+                print(
+                    f"Mask size {mask_size} is greater than the number of image features {unrolled_image_size}."
+                )
+
+            if mask_size == unrolled_image_size:
+                inputs_embeds = self._merge_input_ids_with_image_features(
+                    input_ids, inputs_embeds, image_hidden_states
+                )

         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
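
A toy demonstration (not part of the commit) of the masked assignment that _merge_input_ids_with_image_features performs; the image token id 99 is made up for illustration:

    import torch

    IMAGE_TOKEN_ID = 99  # hypothetical id, stands in for config.image_token_id
    input_ids = torch.tensor([1, 99, 99, 2])  # two image-token slots
    inputs_embeds = torch.zeros(4, 8)         # (seq_len, hidden_size)
    image_features = torch.ones(1, 2, 8)      # (num_images, features, hidden_size)

    mask = input_ids == IMAGE_TOKEN_ID
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
    assert inputs_embeds[1:3].eq(1).all() and inputs_embeds[0].eq(0).all()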


@@ -23,6 +23,75 @@ tracer = trace.get_tracer(__name__)

 IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
 IDEFICS2_IMAGE_TOKEN = "<image>"

+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+def _prompt_split_image(
+    image_seq_len,
+    image_rows,
+    image_cols,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def _prompt_single_image(
+    image_seq_len, fake_token_around_image, image_token, global_img_token
+):
+    """Prompt with expanded image tokens for a single image."""
+    return (
+        f"{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+
+
+def get_image_prompt_string(
+    image_rows,
+    image_cols,
+    image_seq_len,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    if image_rows == 0 and image_cols == 0:
+        return _prompt_single_image(
+            image_seq_len,
+            fake_token_around_image=fake_token_around_image,
+            image_token=image_token,
+            global_img_token=global_img_token,
+        )
+    return _prompt_split_image(
+        image_seq_len,
+        image_rows,
+        image_cols,
+        fake_token_around_image,
+        image_token,
+        global_img_token,
+    )
+
+
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         image_str *= 5
         return image_str

     if config.model_type == "idefics3":
-        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+        # TODO: avoid using hardcoded values
+        image_seq_len = 169  # default value
+        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+        image_str = get_image_prompt_string(
+            n_rows,
+            n_cols,
+            image_seq_len,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
         return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
@@ -80,6 +163,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
+    if config.model_type == "idefics3":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
     return text
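
A tiny illustration (not part of the commit) of why this fixup is needed: two adjacent images each contribute their own surrounding fake token, leaving a doubled token at the seam that must be collapsed:

    FAKE = "<fake_token_around_image>"
    text = f"{FAKE}img1{FAKE}{FAKE}img2{FAKE}"  # seam between two image prompts
    assert text.replace(f"{FAKE}{FAKE}", FAKE) == f"{FAKE}img1{FAKE}img2{FAKE}"
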
@@ -182,7 +269,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")

         if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+            image_inputs = processor.image_processor(
+                images, return_tensors="pt", return_row_col_info=True
+            )
         else:
             image_inputs = None
@@ -196,9 +285,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             if chunk_type == "text":
                 full_text += chunk.text
             elif chunk_type == "image":
-                full_text += image_text_replacement(
+                replacement_text = image_text_replacement(
                     processor, image_inputs, config, image_id
                 )
+                full_text += replacement_text
                 image_id += 1

         full_text = image_text_replacement_fixup(config, full_text)
@@ -268,7 +358,7 @@ class VlmCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            **processor_kwargs,
+            # **processor_kwargs,
         )
         self.batch_class = batch_class
+        # import ipdb; ipdb.set_trace()