vlm fixes

Daniël de Kok 2024-09-03 10:30:15 +00:00
parent e6c524c66b
commit 029d2719c1
4 changed files with 16 additions and 10 deletions

View File

@@ -58,7 +58,10 @@ fn get_config(
     };
     let content = std::fs::read_to_string(filename)?;
-    let config: RawConfig = serde_json::from_str(&content)?;
+    let mut config: RawConfig = serde_json::from_str(&content)?;
+    if let Some(text_config) = config.text_config {
+        config = *text_config;
+    }
     let config: Config = config.into();
     Ok(config)
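
The added lines flatten a vision-language model config: when config.json nests the language model's parameters under text_config, that nested object replaces the outer RawConfig before the conversion into Config. A minimal, self-contained sketch of that behavior (not part of the commit; the trimmed-down RawConfig and the JSON values are illustrative only):

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct RawConfig {
        model_type: Option<String>,
        head_dim: Option<usize>,
        // VLM configs (Llava-style and similar) nest the text model here.
        text_config: Option<Box<RawConfig>>,
    }

    fn main() -> Result<(), serde_json::Error> {
        let content = r#"{
            "model_type": "llava_next",
            "text_config": { "model_type": "mistral", "head_dim": 128 }
        }"#;
        let mut config: RawConfig = serde_json::from_str(content)?;
        // Same unwrapping as in the hunk above: prefer the nested text config.
        if let Some(text_config) = config.text_config {
            config = *text_config;
        }
        assert_eq!(config.model_type.as_deref(), Some("mistral"));
        assert_eq!(config.head_dim, Some(128));
        Ok(())
    }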
@@ -79,7 +82,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             prefix_caching = Some("0".to_string());
         }
         match config.model_type.as_deref() {
-            Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+            Some("gemma") | Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                 // Required because gemma2 needs bfloat16 which is not supported by
                 // flashinfer ?
                 if attention.is_none() {
@@ -96,7 +99,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
+                    tracing::info!("Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching", config.head_dim);
                     attention = Some("flashdecoding".to_string());
                 }
                 if prefix_caching.is_none() {
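
Side note on the amended log line: head_dim is an Option<usize> in RawConfig (next hunk) and, judging by the {:?} placeholder, stays optional in Config, so the log renders it debug-style as Some(96) or None. A tiny illustration, not from the commit:

    fn main() {
        let head_dim: Option<usize> = Some(96);
        // tracing::info! interpolates with the same std formatting as println!.
        println!(
            "Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching",
            head_dim
        );
        // Prints: ... head dim (Some(96)) is not supported by flashinfer ...
    }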
@@ -122,6 +125,7 @@ struct RawConfig {
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
     is_encoder_decoder: Option<bool>,
+    text_config: Option<Box<RawConfig>>,
 }
 
 #[derive(Deserialize)]

View File

@@ -7,6 +7,7 @@ pub struct LlavaNext {
     pub(crate) text_config: TextConfig,
     pub(crate) vision_config: VisionConfig,
     pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
+    pub(crate) image_token_index: u32,
 }
 
 fn get_anyres_image_grid_shape(
@@ -112,7 +113,9 @@ pub struct ClipVisionModel {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics2 {}
+pub struct Idefics2 {
+    pub(crate) image_token_id: u32,
+}
 
 impl Idefics2 {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
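
Both new fields are read straight from the model's config.json by serde: Idefics2 exposes its image placeholder token as image_token_id, while LlavaNext (and Paligemma) expose image_token_index. A trimmed sketch of the Idefics2 side (not part of the commit; the JSON and the token value 32001 are illustrative, and the wrapping Config enum is omitted):

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Idefics2 {
        image_token_id: u32,
    }

    fn main() -> Result<(), serde_json::Error> {
        // Fields not declared on the struct (e.g. model_type) are ignored by default.
        let raw = r#"{ "model_type": "idefics2", "image_token_id": 32001 }"#;
        let config: Idefics2 = serde_json::from_str(raw)?;
        assert_eq!(config.image_token_id, 32001);
        Ok(())
    }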

View File

@@ -1,5 +1,5 @@
 /// Payload validation logic
-use crate::config::Config;
+use crate::config::{Config, Idefics2};
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
@@ -605,7 +605,9 @@ fn image_tokens(
 fn image_id(config: &Config) -> u32 {
     use Config::*;
     match config {
-        Paligemma(pali_gemma) => pali_gemma.image_token_index,
+        Idefics2(idefics) => idefics.image_token_id,
+        LlavaNext(llava) => llava.image_token_index,
+        Paligemma(paligemma) => paligemma.image_token_index,
         _ => unimplemented!("Images tokens are not supported for this model configuration"),
     }
 }
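
With these arms, image_id resolves the image placeholder token for all three supported VLM families instead of only Paligemma. A standalone sketch of the dispatch with structs and enum trimmed to the relevant fields (not the real Config definition; the token value is illustrative):

    struct Idefics2 { image_token_id: u32 }
    struct LlavaNext { image_token_index: u32 }
    struct Paligemma { image_token_index: u32 }

    enum Config {
        Idefics2(Idefics2),
        LlavaNext(LlavaNext),
        Paligemma(Paligemma),
    }

    // Idefics2 carries image_token_id; the other two carry image_token_index.
    fn image_id(config: &Config) -> u32 {
        use Config::*;
        match config {
            Idefics2(idefics) => idefics.image_token_id,
            LlavaNext(llava) => llava.image_token_index,
            Paligemma(paligemma) => paligemma.image_token_index,
        }
    }

    fn main() {
        let config = Config::LlavaNext(LlavaNext { image_token_index: 32000 });
        assert_eq!(image_id(&config), 32000);
    }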

View File

@@ -256,8 +256,6 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
-        if PREFIX_CACHING:
-            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
@@ -347,7 +345,6 @@ class VlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -358,7 +355,7 @@ class VlmCausalLM(FlashCausalLM):
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
             input_lengths = input_lengths + prefix_lens_tensor
-            if PREFIX_CACHING:
+            if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,