vlm fixes

commit 029d2719c1
parent e6c524c66b
@@ -58,7 +58,10 @@ fn get_config(
     };

     let content = std::fs::read_to_string(filename)?;
-    let config: RawConfig = serde_json::from_str(&content)?;
+    let mut config: RawConfig = serde_json::from_str(&content)?;
+    if let Some(text_config) = config.text_config {
+        config = *text_config;
+    }

     let config: Config = config.into();
     Ok(config)
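
For intuition, here is a minimal, self-contained sketch of the fallback introduced above, assuming a VLM-style config.json that nests the language-model settings under a text_config key. The struct is trimmed to two fields and the JSON values are illustrative; only the text_config handling mirrors the diff:

use serde::Deserialize;

// Simplified stand-in for the launcher's RawConfig; only the fields
// needed to show the nesting are included here.
#[derive(Debug, Deserialize)]
struct RawConfig {
    model_type: Option<String>,
    num_attention_heads: Option<usize>,
    // Boxed because the type is recursive: a VLM config may embed a
    // full text-model config under `text_config`.
    text_config: Option<Box<RawConfig>>,
}

fn main() -> Result<(), serde_json::Error> {
    let json = r#"{
        "model_type": "llava_next",
        "text_config": { "model_type": "llama", "num_attention_heads": 32 }
    }"#;

    let mut config: RawConfig = serde_json::from_str(json)?;
    // Mirror of the change above: when a nested text config exists,
    // use it in place of the top-level (vision-level) config.
    if let Some(text_config) = config.text_config {
        config = *text_config;
    }
    assert_eq!(config.model_type.as_deref(), Some("llama"));
    assert_eq!(config.num_attention_heads, Some(32));
    Ok(())
}
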
@@ -79,7 +82,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
            prefix_caching = Some("0".to_string());
        }
        match config.model_type.as_deref() {
-           Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+           Some("gemma") | Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                // Required because gemma2 needs bfloat16 which is not supported by
                // flashinfer ?
                if attention.is_none() {
@@ -96,7 +99,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
            }
            _ => {
                if attention.is_none() {
-                   tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
+                   tracing::info!("Forcing flash decoding because head dim ({:?}) is not supported by flashinfer, also disabling prefix caching", config.head_dim);
                    attention = Some("flashdecoding".to_string());
                }
                if prefix_caching.is_none() {
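
Condensed sketch of the backend-selection rule these two hunks adjust. This is a hypothetical, trimmed-down function, not the real resolve_attention (which also consults lora_adapters and environment overrides), and the set of head dims shown as flashinfer-compatible is an assumption for illustration:

// Hypothetical, trimmed-down version of the fallback rule this diff edits.
fn pick_attention(model_type: &str, head_dim: Option<usize>) -> &'static str {
    match model_type {
        // gemma is now grouped with the models that avoid flashinfer
        // (e.g. because of the bfloat16 requirement noted in the diff).
        "gemma" | "gemma2" | "falcon" | "deepseek_v2" => "flashdecoding",
        _ => match head_dim {
            // Assumed flashinfer-compatible head dims, for illustration only.
            Some(64) | Some(128) | Some(256) => "flashinfer",
            // Anything else falls back, which is what the new log line
            // reports together with the offending head_dim value.
            _ => "flashdecoding",
        },
    }
}

fn main() {
    assert_eq!(pick_attention("gemma", Some(256)), "flashdecoding");
    assert_eq!(pick_attention("llama", Some(128)), "flashinfer");
    assert_eq!(pick_attention("llama", Some(96)), "flashdecoding");
}
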
@@ -122,6 +125,7 @@ struct RawConfig {
    num_attention_heads: Option<usize>,
    head_dim: Option<usize>,
    is_encoder_decoder: Option<bool>,
+   text_config: Option<Box<RawConfig>>,
}

#[derive(Deserialize)]

@@ -7,6 +7,7 @@ pub struct LlavaNext {
    pub(crate) text_config: TextConfig,
    pub(crate) vision_config: VisionConfig,
    pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
+   pub(crate) image_token_index: u32,
}

fn get_anyres_image_grid_shape(
@@ -112,7 +113,9 @@ pub struct ClipVisionModel {

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
-pub struct Idefics2 {}
+pub struct Idefics2 {
+    pub(crate) image_token_id: u32,
+}

impl Idefics2 {
    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
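
The new image_token_index / image_token_id fields mirror keys in each model's config.json. A minimal sketch of how such a field would be populated, assuming serde/serde_json as the surrounding derives suggest; the struct is trimmed and the token id value is only an example:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
struct Idefics2 {
    image_token_id: u32,
}

fn main() {
    // Idefics2 config.json files carry the id of the image placeholder
    // token; 32001 here is only an example value.
    let raw = r#"{ "image_token_id": 32001 }"#;
    let cfg: Idefics2 = serde_json::from_str(raw).unwrap();
    assert_eq!(cfg.image_token_id, 32001);
}
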

@@ -1,5 +1,5 @@
/// Payload validation logic
-use crate::config::Config;
+use crate::config::{Config, Idefics2};
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{
    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
|
@ -605,7 +605,9 @@ fn image_tokens(
|
||||||
fn image_id(config: &Config) -> u32 {
|
fn image_id(config: &Config) -> u32 {
|
||||||
use Config::*;
|
use Config::*;
|
||||||
match config {
|
match config {
|
||||||
Paligemma(pali_gemma) => pali_gemma.image_token_index,
|
Idefics2(idefics) => idefics.image_token_id,
|
||||||
|
LlavaNext(llava) => llava.image_token_index,
|
||||||
|
Paligemma(paligemma) => paligemma.image_token_index,
|
||||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
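
Self-contained sketch of the dispatch this hunk extends, using a hypothetical, minimal mirror of the router's Config enum (the real enum has many more variants, and the token ids here are example values):

// Hypothetical, minimal mirror of the router's Config enum, just to
// show how image_id now resolves the placeholder-token id per model.
struct Idefics2Cfg { image_token_id: u32 }
struct LlavaNextCfg { image_token_index: u32 }
struct PaligemmaCfg { image_token_index: u32 }

enum Config {
    Idefics2(Idefics2Cfg),
    LlavaNext(LlavaNextCfg),
    Paligemma(PaligemmaCfg),
}

fn image_id(config: &Config) -> u32 {
    use Config::*;
    match config {
        Idefics2(idefics) => idefics.image_token_id,
        LlavaNext(llava) => llava.image_token_index,
        Paligemma(paligemma) => paligemma.image_token_index,
    }
}

fn main() {
    let cfg = Config::LlavaNext(LlavaNextCfg { image_token_index: 32000 });
    // The validation layer uses this id to splice image placeholder
    // tokens into the prompt.
    assert_eq!(image_id(&cfg), 32000);
}
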

@@ -256,8 +256,6 @@ class VlmCausalLM(FlashCausalLM):
        trust_remote_code: bool,
        **kwargs,
    ):
-       if PREFIX_CACHING:
-           raise NotImplementedError("Vlm do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
@@ -347,7 +345,6 @@ class VlmCausalLM(FlashCausalLM):
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

-       bs = input_ids.shape[0]
        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
@ -358,7 +355,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||||
cuda_graph = None
|
cuda_graph = None
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||||
input_lengths = input_lengths + prefix_lens_tensor
|
input_lengths = input_lengths + prefix_lens_tensor
|
||||||
if PREFIX_CACHING:
|
if ATTENTION == "flashinfer":
|
||||||
block_tables = block_tables_to_ragged(
|
block_tables = block_tables_to_ragged(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
input_lengths=batch.input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
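
The last hunk ties the ragged block-table layout to the flashinfer attention backend rather than to prefix caching. For intuition, a hypothetical sketch of what a ragged (CSR-style) layout looks like; this is not the real block_tables_to_ragged, which operates on tensors:

// Hypothetical illustration of a ragged block-table layout: instead of
// a padded 2-D [batch, max_blocks] table, store one flat buffer plus
// per-request offsets, which is the shape flashinfer-style kernels
// index into.
fn to_ragged(block_tables: &[Vec<u32>]) -> (Vec<u32>, Vec<usize>) {
    let mut flat = Vec::new();
    let mut indptr = vec![0];
    for table in block_tables {
        flat.extend_from_slice(table);
        indptr.push(flat.len());
    }
    (flat, indptr)
}

fn main() {
    let tables = vec![vec![0, 1, 2], vec![3], vec![4, 5]];
    let (flat, indptr) = to_ragged(&tables);
    assert_eq!(flat, vec![0, 1, 2, 3, 4, 5]);
    // Request i owns flat[indptr[i]..indptr[i + 1]].
    assert_eq!(&flat[indptr[1]..indptr[2]], &[3]);
}
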