Don't enable prefix caching on VLM just yet.

Nicolas Patry 2024-08-27 09:58:19 +02:00
parent e30fb25444
commit f1c0735453
1 changed file with 4 additions and 4 deletions


@@ -68,16 +68,16 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
+        if config.vision_config.is_some() && prefix_caching.is_none() {
+            tracing::info!("Disabling prefix caching because of VLM model");
+            prefix_caching = Some("0".to_string());
+        }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
                     tracing::info!("Disabling prefix caching because of lora adapters");
                     prefix_caching = Some("0".to_string());
                 }
-                if config.vision_config.is_some() && prefix_caching.is_none() {
-                    tracing::info!("Disabling prefix caching because of VLM model");
-                    prefix_caching = Some("0".to_string());
-                }
                 match config.model_type.as_deref() {
                     Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                         // Required because gemma2 needs bfloat16 which is not supported by
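The effect of the reordering shown above can be summarised with a small standalone sketch. This is not the launcher's actual code: MiniConfig, resolve, and the final fallback value of "1" are illustrative assumptions, and the real function also resolves the ATTENTION backend. What it shows is the resolution order after this change: an explicit USE_PREFIX_CACHING environment variable still takes precedence, and the VLM check runs before the head_dim match, so it applies to vision models regardless of their head dimension.

// Minimal sketch of the resolution order, assuming a simplified config.
// Names and the default value are hypothetical; the real launcher uses
// tracing::info! instead of eprintln! and handles more cases.

struct MiniConfig {
    vision_config: Option<()>, // Some(_) means the model is a VLM
    head_dim: Option<u32>,
}

fn resolve(config: &MiniConfig, lora_adapters: &Option<String>) -> String {
    // An explicit USE_PREFIX_CACHING env var always wins: the checks below
    // only fire while prefix_caching is still None.
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();

    // The VLM check now runs before the head_dim match, so every vision
    // model is caught, whatever its head dimension.
    if config.vision_config.is_some() && prefix_caching.is_none() {
        eprintln!("Disabling prefix caching because of VLM model");
        prefix_caching = Some("0".to_string());
    }

    if let Some(h) = config.head_dim {
        if (h == 64 || h == 128 || h == 256)
            && lora_adapters.is_some()
            && prefix_caching.is_none()
        {
            eprintln!("Disabling prefix caching because of lora adapters");
            prefix_caching = Some("0".to_string());
        }
    }

    // Anything still unset falls back to the default (assumed "1" here).
    prefix_caching.unwrap_or_else(|| "1".to_string())
}

fn main() {
    let vlm = MiniConfig { vision_config: Some(()), head_dim: Some(96) };
    // With USE_PREFIX_CACHING unset this prints "0": the VLM check fires
    // even though head_dim is not 64/128/256.
    println!("prefix_caching = {}", resolve(&vlm, &None));
}

Because every override is guarded by prefix_caching.is_none(), a user who explicitly sets USE_PREFIX_CACHING keeps their setting even for a VLM; the commit only changes what happens when the variable is left unset.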