Use flashinfer for Gemma 2.

This commit is contained in:
Daniël de Kok 2024-10-15 13:49:32 +00:00
parent cf04a43fb1
commit ce7e356561
1 changed file with 1 addition and 1 deletion

View File

@@ -94,7 +94,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
-Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {