parent
521d0d990f
commit
cce475a949
|
@ -906,8 +906,8 @@ class FlashCausalLM(Model):
|
|||
# Validation is done in the model itself
|
||||
if num_kv_heads is None:
|
||||
# Order is important here.
|
||||
for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]:
|
||||
num_kv_heads = getattr(config, "num_attention_heads", None)
|
||||
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
|
||||
num_kv_heads = getattr(config, attr, None)
|
||||
if num_kv_heads is not None:
|
||||
break
|
||||
if num_kv_heads is None:
|
||||
|
|
Loading…
Reference in New Issue