From cce475a9491bc011f8f0e89a2ef8d3a1bef88e74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 09:52:12 +0200 Subject: [PATCH] hotfix: Fix number of KV heads (#2202) Fix number of KV heads --- server/text_generation_server/models/flash_causal_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e66011a1..42b2f686 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -906,8 +906,8 @@ class FlashCausalLM(Model): # Validation is done in the model itself if num_kv_heads is None: # Order is important here. - for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: - num_kv_heads = getattr(config, "num_attention_heads", None) + for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: + num_kv_heads = getattr(config, attr, None) if num_kv_heads is not None: break if num_kv_heads is None: