Removing encoder_decoder (seq2seq).

Nicolas Patry 2024-08-27 21:11:49 +02:00
parent ccaf1d0030
commit 8ac1ffa087
1 changed file with 13 additions and 3 deletions


@@ -68,9 +68,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
-        if config.vision_config.is_some() && prefix_caching.is_none() {
-            tracing::info!("Disabling prefix caching because of VLM model");
-            prefix_caching = Some("0".to_string());
+        if prefix_caching.is_none() {
+            if config.vision_config.is_some() {
+                tracing::info!("Disabling prefix caching because of VLM model");
+                prefix_caching = Some("0".to_string());
+            } else if config.is_encoder_decoder {
+                tracing::info!("Disabling prefix caching because of seq2seq model");
+                prefix_caching = Some("0".to_string());
+            }
         }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
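
The hunk above changes the precedence for prefix caching: an explicit USE_PREFIX_CACHING environment variable still wins, otherwise caching is disabled when the model is a VLM (has a vision_config) and, with this commit, also when it is an encoder-decoder (seq2seq) model. The following is a minimal, self-contained restatement of that decision for illustration; PrefixConfig and resolve_prefix_caching are names made up for this sketch, not the launcher's API, and the final defaulting when nothing matches happens elsewhere in the launcher.

// Hypothetical, simplified restatement of the decision in the hunk above;
// `PrefixConfig` and `resolve_prefix_caching` exist only for this sketch.
struct PrefixConfig {
    has_vision_config: bool,
    is_encoder_decoder: bool,
}

fn resolve_prefix_caching(env_override: Option<String>, config: &PrefixConfig) -> Option<String> {
    if env_override.is_some() {
        // An explicit USE_PREFIX_CACHING value always wins.
        return env_override;
    }
    if config.has_vision_config {
        // VLM: prefix caching disabled.
        Some("0".to_string())
    } else if config.is_encoder_decoder {
        // seq2seq: prefix caching disabled (the new branch in this commit).
        Some("0".to_string())
    } else {
        // Leave unset; later defaulting (not shown in this hunk) applies.
        None
    }
}

fn main() {
    // A T5-style seq2seq model with no env override now resolves to "0".
    let t5_like = PrefixConfig { has_vision_config: false, is_encoder_decoder: true };
    assert_eq!(resolve_prefix_caching(None, &t5_like), Some("0".to_string()));
    // An explicit override is still respected.
    assert_eq!(resolve_prefix_caching(Some("1".into()), &t5_like), Some("1".to_string()));
}
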
@@ -90,6 +95,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                             attention = Some("flashdecoding".to_string());
                         }
                     }
+                    Some("t5") => {}
                     _ => {}
                 }
             }
@@ -118,6 +124,7 @@ struct RawConfig {
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
     vision_config: Option<VisionConfig>,
+    is_encoder_decoder: Option<bool>,
 }
 
 #[derive(Deserialize)]
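
The raw field is an Option<bool> because many config.json files simply omit is_encoder_decoder; serde maps the missing key to None. A minimal sketch of that behavior, using a trimmed stand-in struct rather than the real RawConfig:

// Trimmed stand-in for RawConfig, only to show the Option<bool> behavior.
use serde::Deserialize;

#[derive(Deserialize)]
struct RawConfigSketch {
    is_encoder_decoder: Option<bool>,
}

fn main() {
    // e.g. a T5-style config that declares the flag explicitly.
    let seq2seq: RawConfigSketch =
        serde_json::from_str(r#"{ "is_encoder_decoder": true }"#).unwrap();
    assert_eq!(seq2seq.is_encoder_decoder, Some(true));

    // A decoder-only config that omits the key still deserializes fine.
    let decoder_only: RawConfigSketch = serde_json::from_str("{}").unwrap();
    assert_eq!(decoder_only.is_encoder_decoder, None);
}
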
@@ -135,6 +142,7 @@ struct Config {
     head_dim: Option<usize>,
     model_type: Option<String>,
     vision_config: Option<VisionConfig>,
+    is_encoder_decoder: bool,
 }
 
 impl From<RawConfig> for Config {
@@ -162,12 +170,14 @@ impl From<RawConfig> for Config {
         });
         let model_type = other.model_type;
         let vision_config = other.vision_config;
+        let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
             model_type,
             vision_config,
+            is_encoder_decoder,
         }
     }
 }
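
The conversion above is where the missing key collapses to a concrete default: unwrap_or(false) treats an absent is_encoder_decoder as a decoder-only model. A small sketch of the same pattern with stand-in types (not the launcher's real RawConfig/Config):

// Stand-in types illustrating the unwrap_or(false) default.
struct RawSketch {
    is_encoder_decoder: Option<bool>,
}

struct ResolvedSketch {
    is_encoder_decoder: bool,
}

impl From<RawSketch> for ResolvedSketch {
    fn from(other: RawSketch) -> Self {
        ResolvedSketch {
            // Same pattern as the hunk above: absent means "not seq2seq".
            is_encoder_decoder: other.is_encoder_decoder.unwrap_or(false),
        }
    }
}

fn main() {
    let missing = ResolvedSketch::from(RawSketch { is_encoder_decoder: None });
    assert!(!missing.is_encoder_decoder);

    let explicit = ResolvedSketch::from(RawSketch { is_encoder_decoder: Some(true) });
    assert!(explicit.is_encoder_decoder);
}
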