diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 1eb1d83d..22bdda3b 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -68,9 +68,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
-        if config.vision_config.is_some() && prefix_caching.is_none() {
-            tracing::info!("Disabling prefix caching because of VLM model");
-            prefix_caching = Some("0".to_string());
+        if prefix_caching.is_none() {
+            if config.vision_config.is_some() {
+                tracing::info!("Disabling prefix caching because of VLM model");
+                prefix_caching = Some("0".to_string());
+            } else if config.is_encoder_decoder {
+                tracing::info!("Disabling prefix caching because of seq2seq model");
+                prefix_caching = Some("0".to_string());
+            }
         }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
@@ -90,6 +95,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                             attention = Some("flashdecoding".to_string());
                         }
                     }
+                    Some("t5") => {}
                     _ => {}
                 }
             }
@@ -118,6 +124,7 @@ struct RawConfig {
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
     vision_config: Option<VisionConfig>,
+    is_encoder_decoder: Option<bool>,
 }
 
 #[derive(Deserialize)]
@@ -135,6 +142,7 @@ struct Config {
     head_dim: Option<usize>,
     model_type: Option<String>,
     vision_config: Option<VisionConfig>,
+    is_encoder_decoder: bool,
 }
 
 impl From<RawConfig> for Config {
@@ -162,12 +170,14 @@ impl From<RawConfig> for Config {
         });
         let model_type = other.model_type;
         let vision_config = other.vision_config;
+        let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
             model_type,
             vision_config,
+            is_encoder_decoder,
         }
     }
 }
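
A minimal sketch of the resulting behavior (illustration only, not part of the patch): with USE_PREFIX_CACHING unset, a seq2seq checkpoint now falls into the new else-if branch and prefix caching resolves to "0". This assumes resolve_attention returns the resolved (prefix_caching, attention) pair as strings, which lies outside the hunks above:

    // Hypothetical T5-style config; the field set matches the Config struct in this diff.
    let config = Some(Config {
        max_position_embeddings: None,
        quantize: None,
        head_dim: None,
        model_type: Some("t5".to_string()),
        vision_config: None,
        is_encoder_decoder: true, // deserialized via RawConfig; a missing field defaults to false
    });
    // No lora adapters and no USE_PREFIX_CACHING/ATTENTION env overrides.
    let (prefix_caching, _attention) = resolve_attention(&config, &None);
    assert_eq!(prefix_caching, "0"); // seq2seq branch disabled prefix caching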