From f9dfd36b92e8e616a6ae4f6ce74210ae1b844fd3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 20 Aug 2024 16:33:39 +0000
Subject: [PATCH] feat: integrate image tokens into inputs embeds

---
 router/src/config.rs                          |  17 +-
 router/src/validation.rs                      | 204 +++++++-----------
 .../models/custom_modeling/idefics2.py        |  46 ++--
 .../models/vlm_causal_lm.py                   | 100 ++++++++-
 4 files changed, 217 insertions(+), 150 deletions(-)

diff --git a/router/src/config.rs b/router/src/config.rs
index 4e754c66..031c76b6 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -112,9 +112,20 @@ pub struct ClipVisionModel {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics3 {
-    pub(crate) vision_encoder_max_image_size: usize,
-    pub(crate) image_seq_len: usize,
+pub struct Idefics3 {}
+
+impl Idefics3 {
+    pub fn get_max_longest_edge(&self) -> usize {
+        364
+    }
+
+    pub fn get_number_of_features(&self) -> usize {
+        169
+    }
+
+    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
+        1456
+    }
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
diff --git a/router/src/validation.rs b/router/src/validation.rs
index ce02376b..fe98fc2a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -594,27 +594,69 @@ fn image_tokens(
         Idefics3(config) => {
             const FAKE: &str = "<fake_token_around_image>";
             const IMAGE: &str = "<image>";
+            const GLOBAL_IMG: &str = "<global-img>";
 
-            let max_size = config.vision_encoder_max_image_size;
-            let calc_splits = |dim: usize| ((dim as f64) / max_size as f64).ceil() as usize;
+            let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
 
-            let num_splits = if height.max(width) > max_size {
-                Some((calc_splits(width), calc_splits(height)))
+            // resize the image if it is larger than max_longest_edge_for_image_resize, keeping the aspect ratio
+            let (height, width) = if height > max_longest_edge_for_image_resize
+                || width > max_longest_edge_for_image_resize
+            {
+                let aspect_ratio = height as f32 / width as f32;
+                if height > width {
+                    (
+                        max_longest_edge_for_image_resize,
+                        (max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
+                    )
+                } else {
+                    (
+                        (max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
+                        max_longest_edge_for_image_resize,
+                    )
+                }
             } else {
-                None
+                (height, width)
             };
-            let num_splits_h = num_splits.map(|(_, h)| h).unwrap_or(1);
-            let image_repeat = IMAGE.repeat(config.image_seq_len);
+            let image_seq_len = config.get_number_of_features();
+            let max_edge = config.get_max_longest_edge();
 
-            if num_splits_h > 1 {
-                let row = format!("{FAKE}{image_repeat}{FAKE}\n");
-                let mut image_string = row.repeat(num_splits_h);
-                image_string.push_str(&format!("\n{FAKE}{image_repeat}{FAKE}"));
-                image_string
+            let (image_rows, image_cols) = if height > max_edge || width > max_edge {
+                (
+                    (height as f32 / max_edge as f32).ceil() as usize,
+                    (width as f32 / max_edge as f32).ceil() as usize,
+                )
             } else {
-                format!("{FAKE}{image_repeat}{FAKE}")
+                (0, 0)
+            };
+
+            let mut image_string = String::new();
+
+            if image_rows == 0 && image_cols == 0 {
+                // Single image case
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            } else {
+                // Split image case
+                for n_h in 0..image_rows {
+                    for n_w in 0..image_cols {
+                        image_string.push_str(FAKE);
+                        image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
+                        image_string.push_str(&IMAGE.repeat(image_seq_len));
+                    }
+                    image_string.push('\n');
+                }
+
+                image_string.push('\n');
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
             }
+
+            image_string
         }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
@@ -1249,10 +1291,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_idefics2_image_tokens() {
-        let config = Config::Idefics3(Idefics3 {
-            vision_encoder_max_image_size: 100,
-            image_seq_len: 1,
-        });
+        let config = Config::Idefics3(Idefics3 {});
 
         let preprocessor_config = Some(&HubPreprocessorConfig::Idefics2Processor(
            Idefics2Preprocessor {
                do_image_splitting: true,
            },
        ));
 
-        let height = 100;
-        let width = 100;
+        let height = 1067;
+        let width = 1600;
 
         let tokens = image_tokens(&config, preprocessor_config, height, width);
 
-        assert_eq!(
-            tokens,
-            "<fake_token_around_image><image><fake_token_around_image>"
-        );
-    }
+        // get all unique tags `<...>` from the tokens
+        let tags: std::collections::HashSet<&str> = tokens
+            .split(|c| c == '<' || c == '>')
+            .filter(|s| !s.is_empty())
+            .collect();
 
-    #[tokio::test]
-    async fn test_idefics3_image_tokens() {
-        let config = Config::Idefics3(Idefics3 {
-            vision_encoder_max_image_size: 980,
-            image_seq_len: 1,
-        });
+        assert_eq!(tags.len(), 17); // all below and `\n` and `\n\n`
+        assert_eq!(tags.contains(&"row_1_col_1"), true);
+        assert_eq!(tags.contains(&"row_1_col_2"), true);
+        assert_eq!(tags.contains(&"row_1_col_3"), true);
+        assert_eq!(tags.contains(&"row_1_col_4"), true);
+        assert_eq!(tags.contains(&"row_2_col_1"), true);
+        assert_eq!(tags.contains(&"row_2_col_2"), true);
+        assert_eq!(tags.contains(&"row_2_col_3"), true);
+        assert_eq!(tags.contains(&"row_2_col_4"), true);
+        assert_eq!(tags.contains(&"row_3_col_1"), true);
+        assert_eq!(tags.contains(&"row_3_col_2"), true);
+        assert_eq!(tags.contains(&"row_3_col_3"), true);
+        assert_eq!(tags.contains(&"row_3_col_4"), true);
+        assert_eq!(tags.contains(&"global-img"), true);
+        assert_eq!(tags.contains(&"image"), true);
+        assert_eq!(tags.contains(&"fake_token_around_image"), true);
 
-        let preprocessor_config = Some(&HubPreprocessorConfig::Idefics3Processor(
-            Idefics2Preprocessor {
-                do_image_splitting: true,
-            },
-        ));
-
-        let height = 100;
-        let width = 100;
-
-        let tokens = image_tokens(&config, preprocessor_config, height, width);
-
-        assert_eq!(
-            tokens,
-            "<fake_token_around_image><image><fake_token_around_image>"
-        );
-    }
-
-    #[tokio::test]
-    async fn test_idefics3_correct_n_fake_tokens() {
-        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
-
-        let tokenizer = Some(get_tokenizer().await);
-
-        let max_best_of = 2;
-        let max_stop_sequence = 3;
-        let max_top_n_tokens = 4;
-        let max_input_length = 5;
-        let max_total_tokens = 6;
-        let disable_grammar_support = true;
-        let workers = 1;
-        let config = Config::Idefics3(Idefics3 {
-            vision_encoder_max_image_size: 100,
-            image_seq_len: 1,
-        });
-        let validation = Validation::new(
-            workers,
-            tokenizer,
-            Some(config),
-            Some(HubPreprocessorConfig::Idefics3Processor(
-                Idefics2Preprocessor {
-                    do_image_splitting: true,
-                },
-            )),
-            max_best_of,
-            max_stop_sequence,
-            max_top_n_tokens,
-            max_input_length,
-            max_total_tokens,
-            disable_grammar_support,
-        );
-
-        let (encoding, chunks) = match validation
-            .tokenize(
-                format!(
-                    "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
-                    PIXEL_GIF, PIXEL_GIF
-                ),
-                None,
-            )
-            .await
-        {
-            Ok(Some((encoding, chunks))) => (encoding, chunks),
-            _ => panic!("Unexpected tokenization failure"),
-        };
-
-        assert!(
-            chunks
-                == vec![
-                    Chunk::Text("test".to_string()).into(),
-                    Chunk::Image(Image {
-                        data: pixel_data.clone(),
-                        mimetype: "image/gif".to_string()
-                    })
-                    .into(),
-                    Chunk::Image(Image {
-                        data: pixel_data.clone(),
-                        mimetype: "image/gif".to_string()
-                    })
-                    .into()
-                ],
-            "Failed to process images",
-        );
-
-        // Verify the number of fake tokens:
-        //
-        // - Two images, each surrounded/separated by a fake token tags = 4.
-        assert_eq!(
-            encoding
-                .get_tokens()
-                .iter()
-                .filter(|t| *t == "<fake_token_around_image>")
-                .count(),
-            4
-        );
+        assert_eq!(tokens.len(), 15_901);
     }
 }
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index a3968bf5..3723873c 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -753,6 +753,19 @@ class Idefics3ForConditionalGeneration(nn.Module):
             config.pad_token_id if config.pad_token_id is not None else -1
         )
 
+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In place merges in vision_embeddings with inputs_embeds."""
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
+        # Let's pray we have enabled enough slots !
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -835,25 +848,22 @@ class Idefics3ForConditionalGeneration(nn.Module):
                 all_states.append(image_hidden_states)
             image_hidden_states = torch.stack(all_states, dim=0)
 
-            # When we generate, we don't want to replace the potential image_token_id that we generated by images
-            # that simply don't exist
-            # TODO: finish implementing the image token replacement
+            # TODO: remove when prefill image tokens are handled correctly
+            # * for now dummy tokens are added instead of the image tokens output by the vision model
+            mask_size = (input_ids == self.config.image_token_id).sum().item()
+            unrolled_image_size = (
+                image_hidden_states.shape[1] * image_hidden_states.shape[2]
+            )
+            diff = mask_size - unrolled_image_size
+            if diff > 0:
+                print(
+                    f"Mask size {mask_size} is greater than the number of image features {unrolled_image_size}."
+                )
 
-        # inputs_embeds = self.inputs_merger(
-        #     input_ids=input_ids,
-        #     inputs_embeds=inputs_embeds,
-        #     image_hidden_states=image_hidden_states,
-        # )
-
-        # import ipdb; ipdb.set_trace()
-        # num_images, _, vision_hidden_size = image_hidden_states.shape
-        # special_image_token_mask = input_ids == self.image_token_id
-        # new_inputs_embeds = inputs_embeds.clone()
-        # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
-        #     inputs_embeds.dtype
-        # )  # cast to the dtype of the input_embeds to support quantized models
-        # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
-        # inputs_embeds = new_inputs_embeds
+            if mask_size == unrolled_image_size:
+                inputs_embeds = self._merge_input_ids_with_image_features(
+                    input_ids, inputs_embeds, image_hidden_states
+                )
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index ffaacb1e..1bb6b74f 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -23,6 +23,75 @@ tracer = trace.get_tracer(__name__)
 
 IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
 IDEFICS2_IMAGE_TOKEN = "<image>"
 
+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+def _prompt_split_image(
+    image_seq_len,
+    image_rows,
+    image_cols,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def _prompt_single_image(
+    image_seq_len, fake_token_around_image, image_token, global_img_token
+):
+    """Prompt with expanded image tokens for a single image."""
+    return (
+        f"{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+
+
+def get_image_prompt_string(
+    image_rows,
+    image_cols,
+    image_seq_len,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    if image_rows == 0 and image_cols == 0:
+        return _prompt_single_image(
+            image_seq_len,
+            fake_token_around_image=fake_token_around_image,
+            image_token=image_token,
+            global_img_token=global_img_token,
+        )
+    return _prompt_split_image(
+        image_seq_len,
+        image_rows,
+        image_cols,
+        fake_token_around_image,
+        image_token,
+        global_img_token,
+    )
+
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         image_str *= 5
         return image_str
     if config.model_type == "idefics3":
-        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
-        image_str = ""
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+
+        # TODO: avoid using hardcoded values
+        image_seq_len = 169  # default value
+        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+
+        image_str = get_image_prompt_string(
+            n_rows,
+            n_cols,
+            image_seq_len,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
         return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
@@ -80,6 +163,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
+    if config.model_type == "idefics3":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
     return text
 
 
@@ -182,7 +269,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+            image_inputs = processor.image_processor(
+                images, return_tensors="pt", return_row_col_info=True
+            )
         else:
             image_inputs = None
 
@@ -196,9 +285,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(
+                    replacement_text = image_text_replacement(
                         processor, image_inputs, config, image_id
                     )
+                    full_text += replacement_text
                     image_id += 1
 
         full_text = image_text_replacement_fixup(config, full_text)
@@ -268,7 +358,7 @@ class VlmCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            **processor_kwargs,
+            # **processor_kwargs,
        )
         self.batch_class = batch_class
         # import ipdb; ipdb.set_trace()
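
A standalone sketch (not part of the patch) of how the router-side Idefics3 placeholder expansion behaves, written in Python so it can be run next to the server helpers. It uses the constants the patch hardcodes (364 px tiles, 169 <image> tokens per tile, a 1456 px resize limit); the function name idefics3_image_tokens is invented for illustration only.

import math

FAKE = "<fake_token_around_image>"
IMAGE = "<image>"
GLOBAL_IMG = "<global-img>"
MAX_EDGE = 364            # Idefics3::get_max_longest_edge()
MAX_RESIZE_EDGE = 1456    # Idefics3::get_max_longest_edge_for_image_resize()
IMAGE_SEQ_LEN = 169       # Idefics3::get_number_of_features()


def idefics3_image_tokens(height: int, width: int) -> str:
    # Resize so the longest edge is at most MAX_RESIZE_EDGE, keeping the aspect ratio.
    if max(height, width) > MAX_RESIZE_EDGE:
        aspect_ratio = height / width
        if height > width:
            height, width = MAX_RESIZE_EDGE, int(MAX_RESIZE_EDGE / aspect_ratio)
        else:
            height, width = int(MAX_RESIZE_EDGE * aspect_ratio), MAX_RESIZE_EDGE

    # Number of 364 px tiles per dimension; (0, 0) means "no splitting".
    if height > MAX_EDGE or width > MAX_EDGE:
        rows, cols = math.ceil(height / MAX_EDGE), math.ceil(width / MAX_EDGE)
    else:
        rows, cols = 0, 0

    if rows == 0 and cols == 0:
        # Single image case: one global image surrounded by fake tokens.
        return f"{FAKE}{GLOBAL_IMG}{IMAGE * IMAGE_SEQ_LEN}{FAKE}"

    # Split image case: one tagged tile per grid cell, then the global image.
    out = ""
    for r in range(rows):
        for c in range(cols):
            out += f"{FAKE}<row_{r + 1}_col_{c + 1}>{IMAGE * IMAGE_SEQ_LEN}"
        out += "\n"
    out += f"\n{FAKE}{GLOBAL_IMG}{IMAGE * IMAGE_SEQ_LEN}{FAKE}"
    return out


if __name__ == "__main__":
    # A 1600x1067 image resizes to fit 1456 px, splits into 3 rows x 4 cols,
    # and expands to the 15,901-character string the new router test asserts.
    tokens = idefics3_image_tokens(height=1067, width=1600)
    print(len(tokens))  # 15901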
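Similarly, a minimal sketch (again not part of the patch) of the server-side merge that _merge_input_ids_with_image_features performs during prefill: the flattened vision features are scattered into the rows of inputs_embeds that hold <image> placeholder tokens, guarded by the same size check used in forward(). All shapes and the image_token_id value below are invented for the example.

import torch

hidden_size = 16
image_token_id = 5  # hypothetical id of the <image> token

# Pretend prefill input: four text tokens and three <image> placeholder tokens.
input_ids = torch.tensor([1, 5, 5, 5, 2, 3, 4])
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)
# Vision tower output: (num_images, tokens_per_image, hidden_size).
image_hidden_states = torch.randn(1, 3, hidden_size)

mask = input_ids == image_token_id
mask_size = int(mask.sum())
unrolled_image_size = image_hidden_states.shape[0] * image_hidden_states.shape[1]

# Only scatter the vision features when every placeholder has a feature to fill it;
# otherwise the dummy placeholder embeddings are left untouched, as in the patch.
if mask_size == unrolled_image_size:
    inputs_embeds[mask] = image_hidden_states.view(-1, hidden_size)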