parent
521d0d990f
commit
cce475a949
|
@ -906,8 +906,8 @@ class FlashCausalLM(Model):
|
||||||
# Validation is done in the model itself
|
# Validation is done in the model itself
|
||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
# Order is important here.
|
# Order is important here.
|
||||||
for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]:
|
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
|
||||||
num_kv_heads = getattr(config, "num_attention_heads", None)
|
num_kv_heads = getattr(config, attr, None)
|
||||||
if num_kv_heads is not None:
|
if num_kv_heads is not None:
|
||||||
break
|
break
|
||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
|
|
Loading…
Reference in New Issue