Fixing Mistral Nemo. (#2276)

Nicolas Patry 2024-07-23 11:16:03 +02:00 committed by GitHub
parent 4700465192
commit abc32537ea
3 changed files with 10 additions and 4 deletions

@@ -762,8 +762,6 @@ def get_model(
                 default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
-                # hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
-                head_size=config_dict["head_dim"],
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
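
For context: this hunk drops the Gemma2-specific workaround, since the generic head_dim fallback added in the FlashCausalLM hunk below now covers it. A minimal sketch of the mismatch the removed comment pointed at, assuming the published google/gemma-2-9b-it config values:

    # Values assumed from google/gemma-2-9b-it config.json (illustrative only).
    hidden_size = 3584
    num_attention_heads = 16
    head_dim = 256

    # The naive derivation lands on the wrong head size:
    assert hidden_size // num_attention_heads == 224  # not the real head dim
    assert head_dim == 256                            # what the weights expect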

@@ -117,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = self.hidden_size // self.num_heads
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
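
This is the case that broke Mistral Nemo: its config sets head_dim explicitly, and the old hidden_size // num_heads derivation computes the wrong value. A sketch of the new resolution order, assuming the mistralai/Mistral-Nemo-Instruct-2407 config values:

    from types import SimpleNamespace

    # Values assumed from mistralai/Mistral-Nemo-Instruct-2407 config.json.
    config = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)

    # Same order as the patched MistralAttention.__init__ above:
    if hasattr(config, "head_dim"):
        head_size = config.head_dim  # 128, what the checkpoint actually uses
    else:
        head_size = config.hidden_size // config.num_attention_heads  # would be 160

    print(head_size)  # 128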

@@ -925,7 +925,12 @@ class FlashCausalLM(Model):
         assert self.num_kv_heads > 0
         if head_size is None:
-            self.head_size = config.hidden_size // config.num_attention_heads
+            # Some models use GQA and different sizes for o_proj and q_proj;
+            # reading head_dim from the config allows for that.
+            if hasattr(config, "head_dim"):
+                self.head_size = config.head_dim
+            else:
+                self.head_size = config.hidden_size // config.num_attention_heads
         else:
             self.head_size = head_size
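
Taken together, the hunks make an explicit head_size argument the strongest override, config.head_dim the preferred source, and hidden_size // num_attention_heads the last-resort fallback. A standalone sketch of that precedence; resolve_head_size is a hypothetical helper name, not part of the codebase:

    from typing import Optional

    def resolve_head_size(config, head_size: Optional[int] = None) -> int:
        # Hypothetical mirror of the FlashCausalLM logic in the hunk above.
        if head_size is not None:
            return head_size  # caller override wins
        if hasattr(config, "head_dim"):
            return config.head_dim  # trust the checkpoint's declared head dim
        return config.hidden_size // config.num_attention_heads  # legacy fallback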