feat: support force downcast after FastRMSNorm multiply for Gemma (#1658)
This PR adds `force_downcast_after` to `FastRMSNorm.forward` which is used in the Gemma model. References https://github.com/huggingface/transformers/pull/29402 and https://github.com/huggingface/transformers/pull/29729 Setting `force_downcast_after=True` will perform the `hidden_states * weight` multiplication in f32 and then downcast to half. This differs slightly from the current implementation which first casts the `hidden_states` to a half and then multiples.
This commit is contained in:
parent
dfbd9a39a2
commit
6f15ac60b2
|
@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
|
|||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu",
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
|
@ -261,6 +261,17 @@ class GemmaFastRMSNorm(FastRMSNorm):
|
|||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
return cls(weight, eps)
|
||||
|
||||
# perform the multiplication in full precision and downcast after
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states.to(self.weight.dtype), residual
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
|
|
Loading…
Reference in New Issue