More specific codes.

This commit is contained in:
Nicolas Patry 2024-08-20 12:05:40 +02:00
parent f5ee062cbd
commit bd0ced354d
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
1 changed files with 12 additions and 4 deletions

View File

@ -47,6 +47,7 @@ struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
}
impl From<RawConfig> for Config {
@ -72,10 +73,12 @@ impl From<RawConfig> for Config {
_ => None,
}
});
let model_type = other.model_type;
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
}
}
}
@ -1492,10 +1495,6 @@ fn main() -> Result<(), LauncherError> {
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("ATTENTION", "flashdecoding");
}
let config: Config = config.into();
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
@ -1504,6 +1503,15 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("Disabling prefix caching because of lora adapters");
std::env::set_var("USE_PREFIX_CACHING", "0");
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") => {
// Presumably required because gemma2 needs bfloat16, which flashinfer does
// not seem to support (unconfirmed). Falcon gets the same fallback here.
std::env::set_var("USE_PREFIX_CACHING", "0");
std::env::set_var("ATTENTION", "flashdecoding");
}
_ => {}
}
}
_ => {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");