More specific codes.
This commit is contained in:
parent
a6cd5fef23
commit
f0b35f94b8
|
@ -47,6 +47,7 @@ struct Config {
|
|||
max_position_embeddings: Option<usize>,
|
||||
quantize: Option<Quantization>,
|
||||
head_dim: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
}
|
||||
|
||||
impl From<RawConfig> for Config {
|
||||
|
@ -72,10 +73,12 @@ impl From<RawConfig> for Config {
|
|||
_ => None,
|
||||
}
|
||||
});
|
||||
let model_type = other.model_type;
|
||||
Config {
|
||||
max_position_embeddings,
|
||||
quantize,
|
||||
head_dim,
|
||||
model_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1492,10 +1495,6 @@ fn main() -> Result<(), LauncherError> {
|
|||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
|
||||
if config.model_type == Some("gemma2".to_string()) {
|
||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
let config: Config = config.into();
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
|
@ -1504,6 +1503,15 @@ fn main() -> Result<(), LauncherError> {
|
|||
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||
}
|
||||
match config.model_type.as_deref() {
|
||||
Some("gemma2") | Some("falcon") => {
|
||||
// Required because gemma2 needs bfloat16 which is not supported by
|
||||
// flashinfer ?
|
||||
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
|
|
Loading…
Reference in New Issue