vlm fixes

Daniël de Kok 2024-09-03 10:30:15 +00:00
parent e6c524c66b
commit 029d2719c1
4 changed files with 16 additions and 10 deletions

View File

@@ -58,7 +58,10 @@ fn get_config(
};
let content = std::fs::read_to_string(filename)?;
- let config: RawConfig = serde_json::from_str(&content)?;
+ let mut config: RawConfig = serde_json::from_str(&content)?;
+ if let Some(text_config) = config.text_config {
+ config = *text_config;
+ }
let config: Config = config.into();
Ok(config)
@@ -79,7 +82,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
Some("gemma") | Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
@@ -96,7 +99,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
tracing::info!("Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching", config.head_dim);
attention = Some("flashdecoding".to_string());
}
if prefix_caching.is_none() {
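
A note on the updated tracing message above: the head dim is now interpolated with {:?}, so the log will show Some(..) or None rather than a bare number, assuming head_dim is an Option<usize> as it is in RawConfig below. A minimal sketch of how that renders, with a made-up value:

// Sketch only: shows how the new log line formats an Option<usize> head dim.
fn main() {
    // Made-up value for illustration.
    let head_dim: Option<usize> = Some(160);
    println!(
        "Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching",
        head_dim
    );
    // Prints: Forcing flash decoding because head dim (Some(160)) is not supported by flashinfer, also disabling prefix caching
}
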
@@ -122,6 +125,7 @@ struct RawConfig {
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
is_encoder_decoder: Option<bool>,
+ text_config: Option<Box<RawConfig>>,
}
#[derive(Deserialize)]
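
The text_config handling above matters for VLM checkpoints (LLaVA-style configs, for instance), where language-model parameters such as num_attention_heads and head_dim sit under a nested text_config object instead of at the top level; get_config now unwraps that nested object before converting to Config. A minimal, self-contained sketch of the unwrapping, assuming serde and serde_json; the struct is a pared-down stand-in for RawConfig and the JSON values are illustrative only:

use serde::Deserialize;

// Pared-down stand-in for the launcher's RawConfig.
#[derive(Deserialize)]
struct RawConfig {
    model_type: Option<String>,
    num_attention_heads: Option<usize>,
    head_dim: Option<usize>,
    text_config: Option<Box<RawConfig>>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // VLM-style config: the text model's parameters live under "text_config".
    let content = r#"{
        "model_type": "llava_next",
        "text_config": { "model_type": "mistral", "num_attention_heads": 32, "head_dim": 128 }
    }"#;

    let mut config: RawConfig = serde_json::from_str(content)?;
    // Same unwrapping as in the diff: continue with the nested text config.
    if let Some(text_config) = config.text_config {
        config = *text_config;
    }
    assert_eq!(config.head_dim, Some(128));
    Ok(())
}
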

View File

@@ -7,6 +7,7 @@ pub struct LlavaNext {
pub(crate) text_config: TextConfig,
pub(crate) vision_config: VisionConfig,
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
+ pub(crate) image_token_index: u32,
}
fn get_anyres_image_grid_shape(
@@ -112,7 +113,9 @@ pub struct ClipVisionModel
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
- pub struct Idefics2 {}
+ pub struct Idefics2 {
+ pub(crate) image_token_id: u32,
+ }
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {

View File

@@ -1,5 +1,5 @@
/// Payload validation logic
- use crate::config::Config;
+ use crate::config::{Config, Idefics2};
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
@@ -605,7 +605,9 @@ fn image_tokens(
fn image_id(config: &Config) -> u32 {
use Config::*;
match config {
- Paligemma(pali_gemma) => pali_gemma.image_token_index,
+ Idefics2(idefics) => idefics.image_token_id,
+ LlavaNext(llava) => llava.image_token_index,
+ Paligemma(paligemma) => paligemma.image_token_index,
_ => unimplemented!("Images tokens are not supported for this model configuration"),
}
}
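
Together, the config.rs and validation.rs changes above let the router read the image placeholder token id from the model config for Idefics2 and LLaVA-NeXT as well as PaliGemma. The field names differ on purpose (image_token_id for Idefics2, image_token_index for the others), mirroring the key names used by the corresponding Hugging Face configs. A self-contained sketch of the dispatch, assuming a model_type-tagged Config enum along the lines of the router's; the tag setup and the token id value are assumptions for illustration:

use serde::Deserialize;

// Pared-down stand-ins for the router config structs touched in this commit.
#[derive(Deserialize)]
struct Idefics2 {
    image_token_id: u32,
}

#[derive(Deserialize)]
struct LlavaNext {
    image_token_index: u32,
}

#[derive(Deserialize)]
struct Paligemma {
    image_token_index: u32,
}

// Assumed shape: an enum tagged by "model_type", as hinted at by the
// #[serde(rename_all = "snake_case")] attribute in the diff above.
#[derive(Deserialize)]
#[serde(tag = "model_type", rename_all = "snake_case")]
enum Config {
    Idefics2(Idefics2),
    LlavaNext(LlavaNext),
    Paligemma(Paligemma),
}

// Same dispatch as the updated image_id() in validation.rs, minus the
// catch-all arm, since this mini enum is exhaustive.
fn image_id(config: &Config) -> u32 {
    match config {
        Config::Idefics2(idefics) => idefics.image_token_id,
        Config::LlavaNext(llava) => llava.image_token_index,
        Config::Paligemma(paligemma) => paligemma.image_token_index,
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Token id value is made up for illustration.
    let config: Config =
        serde_json::from_str(r#"{ "model_type": "llava_next", "image_token_index": 32000 }"#)?;
    assert_eq!(image_id(&config), 32000);
    Ok(())
}
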

View File

@@ -256,8 +256,6 @@ class VlmCausalLM(FlashCausalLM):
trust_remote_code: bool,
**kwargs,
):
- if PREFIX_CACHING:
- raise NotImplementedError("Vlm do not work with prefix caching yet")
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
@@ -347,7 +345,6 @@ class VlmCausalLM(FlashCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
- bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -358,7 +355,7 @@ class VlmCausalLM(FlashCausalLM):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor
- if PREFIX_CACHING:
+ if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,