vlm fixes
This commit is contained in:
parent
e6c524c66b
commit
029d2719c1
|
@ -58,7 +58,10 @@ fn get_config(
|
|||
};
|
||||
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
let mut config: RawConfig = serde_json::from_str(&content)?;
|
||||
if let Some(text_config) = config.text_config {
|
||||
config = *text_config;
|
||||
}
|
||||
|
||||
let config: Config = config.into();
|
||||
Ok(config)
|
||||
|
@ -79,7 +82,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
match config.model_type.as_deref() {
|
||||
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
||||
Some("gemma") | Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
||||
// Required because gemma2 needs bfloat16 which is not supported by
|
||||
// flashinfer ?
|
||||
if attention.is_none() {
|
||||
|
@ -96,7 +99,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||
}
|
||||
_ => {
|
||||
if attention.is_none() {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
tracing::info!("Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching", config.head_dim);
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
if prefix_caching.is_none() {
|
||||
|
@ -122,6 +125,7 @@ struct RawConfig {
|
|||
num_attention_heads: Option<usize>,
|
||||
head_dim: Option<usize>,
|
||||
is_encoder_decoder: Option<bool>,
|
||||
text_config: Option<Box<RawConfig>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
|
|
@ -7,6 +7,7 @@ pub struct LlavaNext {
|
|||
pub(crate) text_config: TextConfig,
|
||||
pub(crate) vision_config: VisionConfig,
|
||||
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
pub(crate) image_token_index: u32,
|
||||
}
|
||||
|
||||
fn get_anyres_image_grid_shape(
|
||||
|
@ -112,7 +113,9 @@ pub struct ClipVisionModel {
|
|||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Idefics2 {}
|
||||
pub struct Idefics2 {
|
||||
pub(crate) image_token_id: u32,
|
||||
}
|
||||
|
||||
impl Idefics2 {
|
||||
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/// Payload validation logic
|
||||
use crate::config::Config;
|
||||
use crate::config::{Config, Idefics2};
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{
|
||||
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
|
||||
|
@ -605,7 +605,9 @@ fn image_tokens(
|
|||
fn image_id(config: &Config) -> u32 {
|
||||
use Config::*;
|
||||
match config {
|
||||
Paligemma(pali_gemma) => pali_gemma.image_token_index,
|
||||
Idefics2(idefics) => idefics.image_token_id,
|
||||
LlavaNext(llava) => llava.image_token_index,
|
||||
Paligemma(paligemma) => paligemma.image_token_index,
|
||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -256,8 +256,6 @@ class VlmCausalLM(FlashCausalLM):
|
|||
trust_remote_code: bool,
|
||||
**kwargs,
|
||||
):
|
||||
if PREFIX_CACHING:
|
||||
raise NotImplementedError("Vlm do not work with prefix caching yet")
|
||||
if processor_kwargs is None:
|
||||
processor_kwargs = {}
|
||||
self.processor = processor_class.from_pretrained(
|
||||
|
@ -347,7 +345,6 @@ class VlmCausalLM(FlashCausalLM):
|
|||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
# Try to find an associated cuda graph
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
|
@ -358,7 +355,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||
cuda_graph = None
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
input_lengths = input_lengths + prefix_lens_tensor
|
||||
if PREFIX_CACHING:
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
|
|
Loading…
Reference in New Issue