Removing encoder_decoder (seq2seq).

This commit is contained in:
Nicolas Patry 2024-08-27 21:11:49 +02:00
parent ccaf1d0030
commit 8ac1ffa087
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
1 changed files with 13 additions and 3 deletions

View File

@ -68,9 +68,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if config.vision_config.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
}
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
@ -90,6 +95,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
attention = Some("flashdecoding".to_string());
}
}
Some("t5") => {}
_ => {}
}
}
@ -118,6 +124,7 @@ struct RawConfig {
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
}
#[derive(Deserialize)]
@ -135,6 +142,7 @@ struct Config {
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
}
impl From<RawConfig> for Config {
@ -162,12 +170,14 @@ impl From<RawConfig> for Config {
});
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
}
}
}