feat: support force downcast after FastRMSNorm multiply for Gemma (#1658)

This PR adds `force_downcast_after` to `FastRMSNorm.forward`, which is
used by the Gemma model. References:
https://github.com/huggingface/transformers/pull/29402 and
https://github.com/huggingface/transformers/pull/29729

Setting `force_downcast_after=True` will perform the `hidden_states *
weight` multiplication in float32 and then downcast the result to half
precision. This differs slightly from the current implementation, which
first casts `hidden_states` to half precision and then multiplies.
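
For illustration, a minimal standalone sketch (hypothetical tensors, not the actual `FastRMSNorm` code path) contrasting the two orderings:

```python
import torch

# Hypothetical stand-ins for the normalized hidden states and the RMSNorm
# weight; in the real layer the weight is typically already stored in the
# model's half-precision dtype.
hidden_states = torch.randn(2, 4, dtype=torch.float32)
weight = torch.randn(4, dtype=torch.float32)

# Current behavior: downcast first, then multiply in half precision.
cast_then_multiply = hidden_states.to(torch.float16) * weight.to(torch.float16)

# force_downcast_after=True: multiply in float32, then downcast the product.
multiply_then_cast = (hidden_states * weight).to(torch.float16)

# The two results agree up to rounding; only the point at which precision
# is lost changes.
print((cast_then_multiply - multiply_then_cast).abs().max())
```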
drbh 2024-03-21 05:25:11 -04:00 committed by GitHub
parent dfbd9a39a2
commit 6f15ac60b2
1 changed file with 12 additions and 1 deletion

@@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -261,6 +261,17 @@ class GemmaFastRMSNorm(FastRMSNorm):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)
+
+    # perform the multiplication in full precision and downcast after
+    def forward(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states * self.weight
+        return hidden_states.to(self.weight.dtype), residual


 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads: