More specific codes.
This commit is contained in:
parent
f5ee062cbd
commit
bd0ced354d
|
@ -47,6 +47,7 @@ struct Config {
|
||||||
max_position_embeddings: Option<usize>,
|
max_position_embeddings: Option<usize>,
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
head_dim: Option<usize>,
|
head_dim: Option<usize>,
|
||||||
|
model_type: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<RawConfig> for Config {
|
impl From<RawConfig> for Config {
|
||||||
|
@ -72,10 +73,12 @@ impl From<RawConfig> for Config {
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
let model_type = other.model_type;
|
||||||
Config {
|
Config {
|
||||||
max_position_embeddings,
|
max_position_embeddings,
|
||||||
quantize,
|
quantize,
|
||||||
head_dim,
|
head_dim,
|
||||||
|
model_type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1492,10 +1495,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
let content = std::fs::read_to_string(filename)?;
|
let content = std::fs::read_to_string(filename)?;
|
||||||
let config: RawConfig = serde_json::from_str(&content)?;
|
let config: RawConfig = serde_json::from_str(&content)?;
|
||||||
|
|
||||||
if config.model_type == Some("gemma2".to_string()) {
|
|
||||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
|
||||||
std::env::set_var("ATTENTION", "flashdecoding");
|
|
||||||
}
|
|
||||||
let config: Config = config.into();
|
let config: Config = config.into();
|
||||||
match config.head_dim {
|
match config.head_dim {
|
||||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||||
|
@ -1504,6 +1503,15 @@ fn main() -> Result<(), LauncherError> {
|
||||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||||
std::env::set_var("USE_PREFIX_CACHING", "0");
|
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||||
}
|
}
|
||||||
|
match config.model_type.as_deref() {
|
||||||
|
Some("gemma2") | Some("falcon") => {
|
||||||
|
// Required because gemma2 needs bfloat16 which is not supported by
|
||||||
|
// flashinfer ?
|
||||||
|
std::env::set_var("USE_PREFIX_CACHING", "0");
|
||||||
|
std::env::set_var("ATTENTION", "flashdecoding");
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||||
|
|
Loading…
Reference in New Issue