From 6f15ac60b2b7885a231af507dbbb8e4c57cb9e8b Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 21 Mar 2024 05:25:11 -0400
Subject: [PATCH] feat: support force downcast after FastRMSNorm multiply for
 Gemma (#1658)

This PR adds `force_downcast_after` to `FastRMSNorm.forward`, which is
used in the Gemma model. References
https://github.com/huggingface/transformers/pull/29402 and
https://github.com/huggingface/transformers/pull/29729

Setting `force_downcast_after=True` will perform the
`hidden_states * weight` multiplication in float32 and then downcast the
result to half precision. This differs slightly from the current
implementation, which first casts `hidden_states` to half precision and
then multiplies.
---
 .../models/custom_modeling/flash_gemma_modeling.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index e91927df..69c1665d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -209,7 +209,7 @@ class GemmaConfig(PretrainedConfig):
         num_attention_heads=16,
         num_key_value_heads=16,
         head_dim=256,
-        hidden_act="gelu",
+        hidden_act="gelu_pytorch_tanh",
         max_position_embeddings=8192,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
@@ -261,6 +261,17 @@ class GemmaFastRMSNorm(FastRMSNorm):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)
 
+    # perform the multiplication in full precision and downcast after
+    def forward(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states * self.weight
+        return hidden_states.to(self.weight.dtype), residual
+
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
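
Aside (not part of the patch): a minimal, self-contained sketch of the ordering difference the commit message describes, multiplying by the weight in float32 and downcasting afterwards versus downcasting the normalized hidden states to half precision before the multiply. The shapes, seed, and `eps` value below are illustrative only and are not taken from the Gemma config.

```python
import torch

torch.manual_seed(0)

# Illustrative half-precision activations and norm weight; eps is an assumed value.
hidden_states = torch.randn(2, 2048, dtype=torch.float16)
weight = torch.randn(2048, dtype=torch.float16) + 1
eps = 1e-6

# Shared RMSNorm core, computed in float32.
x = hidden_states.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)

# Previous behaviour: downcast to half precision first, then multiply by the weight.
out_before = x.to(torch.float16) * weight

# Patched behaviour (as in GemmaFastRMSNorm.forward): multiply in float32, downcast at the end.
out_after = (x * weight).to(torch.float16)

# The two results agree closely but can differ by a few half-precision ULPs.
print((out_before - out_after).abs().max())
```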