Removing encoder_decoder (seq2seq).
This commit is contained in:
parent
ccaf1d0030
commit
8ac1ffa087
|
@ -68,9 +68,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||||
if let Some(config) = config {
|
if let Some(config) = config {
|
||||||
if config.vision_config.is_some() && prefix_caching.is_none() {
|
if prefix_caching.is_none() {
|
||||||
|
if config.vision_config.is_some() {
|
||||||
tracing::info!("Disabling prefix caching because of VLM model");
|
tracing::info!("Disabling prefix caching because of VLM model");
|
||||||
prefix_caching = Some("0".to_string());
|
prefix_caching = Some("0".to_string());
|
||||||
|
} else if config.is_encoder_decoder {
|
||||||
|
tracing::info!("Disabling prefix caching because of seq2seq model");
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
match config.head_dim {
|
match config.head_dim {
|
||||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||||
|
@ -90,6 +95,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||||
attention = Some("flashdecoding".to_string());
|
attention = Some("flashdecoding".to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some("t5") => {}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,6 +124,7 @@ struct RawConfig {
|
||||||
num_attention_heads: Option<usize>,
|
num_attention_heads: Option<usize>,
|
||||||
head_dim: Option<usize>,
|
head_dim: Option<usize>,
|
||||||
vision_config: Option<VisionConfig>,
|
vision_config: Option<VisionConfig>,
|
||||||
|
is_encoder_decoder: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
@ -135,6 +142,7 @@ struct Config {
|
||||||
head_dim: Option<usize>,
|
head_dim: Option<usize>,
|
||||||
model_type: Option<String>,
|
model_type: Option<String>,
|
||||||
vision_config: Option<VisionConfig>,
|
vision_config: Option<VisionConfig>,
|
||||||
|
is_encoder_decoder: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<RawConfig> for Config {
|
impl From<RawConfig> for Config {
|
||||||
|
@ -162,12 +170,14 @@ impl From<RawConfig> for Config {
|
||||||
});
|
});
|
||||||
let model_type = other.model_type;
|
let model_type = other.model_type;
|
||||||
let vision_config = other.vision_config;
|
let vision_config = other.vision_config;
|
||||||
|
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||||
Config {
|
Config {
|
||||||
max_position_embeddings,
|
max_position_embeddings,
|
||||||
quantize,
|
quantize,
|
||||||
head_dim,
|
head_dim,
|
||||||
model_type,
|
model_type,
|
||||||
vision_config,
|
vision_config,
|
||||||
|
is_encoder_decoder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue