From 13c62be467953c4762eeb15fa418bc0f6c716acd Mon Sep 17 00:00:00 2001 From: Dean Wyatte <2512762+dwyatte@users.noreply.github.com> Date: Thu, 1 Feb 2024 01:34:11 -0700 Subject: [PATCH] GPTNeoX: Use static rotary embedding (#1498) # What does this PR do? `transformers` 4.35 removed rotary embeddings from GPTNeoX's weights ([link to line diff](https://github.com/huggingface/transformers/commit/253f9a3f9716d08a81fb305fe71f983122eb608b#diff-0e2a05d86c82e96f516db8c14070ceb36f53ca44c6bc21a9cd92ad2e777b9cf1R298)). This applies the same fix as https://github.com/huggingface/text-generation-inference/pull/793 which generates them on-the-fly using the appropriate value from the config file Fixes https://github.com/huggingface/text-generation-inference/issues/1460 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [x] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? @OlivierDehaene OR @Narsil --- .../models/custom_modeling/flash_neox_modeling.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index eea5f787..3ee344e4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + self.rotary_dim = int(config.rotary_pct * self.head_size) + if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module): ) self.num_heads = self.num_heads // weights.process_group.size() - self.rotary_emb = PositionRotaryEmbedding.load( - config=config, prefix=f"{prefix}.rotary_emb", weights=weights + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.rotary_dim, + base=config.rotary_emb_base, + device=weights.device, ) self.softmax_scale = self.head_size ** (-0.5)