From 029d2719c1287027360180a2b6c52bf1350dc723 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 3 Sep 2024 10:30:15 +0000
Subject: [PATCH] vlm fixes

---
 launcher/src/main.rs                                   | 10 +++++++---
 router/src/config.rs                                   |  5 ++++-
 router/src/validation.rs                               |  6 ++++--
 server/text_generation_server/models/vlm_causal_lm.py  |  5 +----
 4 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 49fb3998..656d3d0e 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -58,7 +58,10 @@ fn get_config(
     };
 
     let content = std::fs::read_to_string(filename)?;
-    let config: RawConfig = serde_json::from_str(&content)?;
+    let mut config: RawConfig = serde_json::from_str(&content)?;
+    if let Some(text_config) = config.text_config {
+        config = *text_config;
+    }
 
     let config: Config = config.into();
     Ok(config)
@@ -79,7 +82,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             prefix_caching = Some("0".to_string());
         }
         match config.model_type.as_deref() {
-            Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+            Some("gemma") | Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                 // Required because gemma2 needs bfloat16 which is not supported by
                 // flashinfer ?
                 if attention.is_none() {
@@ -96,7 +99,7 @@
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
+                    tracing::info!("Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching", config.head_dim);
                     attention = Some("flashdecoding".to_string());
                 }
                 if prefix_caching.is_none() {
@@ -122,6 +125,7 @@ struct RawConfig {
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
     is_encoder_decoder: Option<bool>,
+    text_config: Option<Box<RawConfig>>,
 }
 
 #[derive(Deserialize)]
diff --git a/router/src/config.rs b/router/src/config.rs
index af148217..a28d0577 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -7,6 +7,7 @@ pub struct LlavaNext {
     pub(crate) text_config: TextConfig,
     pub(crate) vision_config: VisionConfig,
     pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
+    pub(crate) image_token_index: u32,
 }
 
 fn get_anyres_image_grid_shape(
@@ -112,7 +113,9 @@ pub struct ClipVisionModel {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics2 {}
+pub struct Idefics2 {
+    pub(crate) image_token_id: u32,
+}
 
 impl Idefics2 {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index fd57c26a..879e9bf3 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,5 +1,5 @@
 /// Payload validation logic
-use crate::config::Config;
+use crate::config::{Config, Idefics2};
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
@@ -605,7 +605,9 @@ fn image_tokens(
 fn image_id(config: &Config) -> u32 {
     use Config::*;
     match config {
-        Paligemma(pali_gemma) => pali_gemma.image_token_index,
+        Idefics2(idefics) => idefics.image_token_id,
+        LlavaNext(llava) => llava.image_token_index,
+        Paligemma(paligemma) => paligemma.image_token_index,
         _ => unimplemented!("Images tokens are not supported for this model configuration"),
     }
 }
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index d6cb36fa..da6d06d6 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -256,8 +256,6 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
-        if PREFIX_CACHING:
-            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
@@ -347,7 +345,6 @@ class VlmCausalLM(FlashCausalLM):
         # This makes sure the max_s for the decode pass is correct.
         max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -358,7 +355,7 @@ class VlmCausalLM(FlashCausalLM):
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
            input_lengths = input_lengths + prefix_lens_tensor
-            if PREFIX_CACHING:
+            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
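
Appendix (not part of the patch): a minimal, self-contained sketch of the nested text_config fallback that the get_config hunk introduces. VLM checkpoints such as LLaVA-NeXT or PaliGemma keep the text model's hyperparameters under a nested "text_config" object, so top-level fields like num_attention_heads deserialize as None unless the launcher unwraps that nesting. The RawConfig below is a trimmed stand-in for the launcher's struct (its field set and types here are assumptions), and the sample config.json values are illustrative only.

// Sketch only: trimmed stand-in for the launcher's RawConfig; not the full
// struct from launcher/src/main.rs. Needs serde (with the derive feature)
// and serde_json.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct RawConfig {
    model_type: Option<String>,
    num_attention_heads: Option<usize>,
    head_dim: Option<usize>,
    // Boxed so the recursive type has a finite size.
    text_config: Option<Box<RawConfig>>,
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative VLM-style config.json: the text model's parameters are nested.
    let content = r#"{
        "model_type": "llava_next",
        "text_config": {
            "model_type": "mistral",
            "num_attention_heads": 32,
            "head_dim": 128
        }
    }"#;

    let mut config: RawConfig = serde_json::from_str(content)?;
    // Same fallback as the patched get_config: prefer the nested text config,
    // so downstream attention resolution sees the text model's head_dim and
    // model_type instead of a top-level config with those fields missing.
    if let Some(text_config) = config.text_config {
        config = *text_config;
    }

    assert_eq!(config.head_dim, Some(128));
    assert_eq!(config.model_type.as_deref(), Some("mistral"));
    println!("{config:?}");
    Ok(())
}

Reusing RawConfig for the nested object (behind a Box) keeps the fallback a one-field addition rather than a separate text-config type; any field absent from the nested object simply stays None, which matches how the launcher already treats optional config fields.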