Falcon/DBRX: get correct number of key-value heads (#2205)

This commit is contained in:
Daniël de Kok 2024-07-08 13:22:38 +02:00 committed by GitHub
parent 153fcf7739
commit 5c7c9f1390
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 6 deletions

View File

@ -797,6 +797,10 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,

View File

@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
class DbrxConfig(PretrainedConfig):
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "n_heads",
"num_hidden_layers": "n_layers",
}
def __init__(
self,
d_model: int = 2048,
@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
**kwargs,
)
@property
def num_key_value_heads(self):
# We can't use the attribute map, since this the number of KV
# heads is not top-level.
return self.attn_config.kv_n_heads
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x

View File

@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
"num_key_value_heads": "n_head_kv",
}
def __init__(

View File

@ -905,13 +905,12 @@ class FlashCausalLM(Model):
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
# Order is important here.
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None:
break
num_kv_heads = getattr(config, "num_key_value_heads", None)
# GPT-2 workaround
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
num_kv_heads = getattr(config, "n_head", None)
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1