hotfix: Fix number of KV heads (#2202)

Fix number of KV heads
This commit is contained in:
Daniël de Kok 2024-07-08 09:52:12 +02:00 committed by GitHub
parent 521d0d990f
commit cce475a949
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 2 deletions

View File

@ -906,8 +906,8 @@ class FlashCausalLM(Model):
# Validation is done in the model itself # Validation is done in the model itself
if num_kv_heads is None: if num_kv_heads is None:
# Order is important here. # Order is important here.
for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, "num_attention_heads", None) num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None: if num_kv_heads is not None:
break break
if num_kv_heads is None: if num_kv_heads is None: