Fixing mistral nemo. (#2276)
This commit is contained in:
parent
4700465192
commit
abc32537ea
|
@ -762,8 +762,6 @@ def get_model(
|
|||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
|
||||
head_size=config_dict["head_dim"],
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||
|
|
|
@ -117,7 +117,10 @@ class MistralAttention(torch.nn.Module):
|
|||
)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
if hasattr(config, "head_dim"):
|
||||
self.head_size = config.head_dim
|
||||
else:
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
|
|
|
@ -925,7 +925,12 @@ class FlashCausalLM(Model):
|
|||
assert self.num_kv_heads > 0
|
||||
|
||||
if head_size is None:
|
||||
self.head_size = config.hidden_size // config.num_attention_heads
|
||||
# Some models use GQA and different sizes for o_proj
|
||||
# and q_proj, that allows for that.
|
||||
if hasattr(config, "head_dim"):
|
||||
self.head_size = config.head_dim
|
||||
else:
|
||||
self.head_size = config.hidden_size // config.num_attention_heads
|
||||
else:
|
||||
self.head_size = head_size
|
||||
|
||||
|
|
Loading…
Reference in New Issue