GPTNeoX: Use static rotary embedding (#1498)
# What does this PR do?
`transformers` 4.35 removed rotary embeddings from GPTNeoX's weights
([link to line
diff](253f9a3f97 (diff-0e2a05d86c82e96f516db8c14070ceb36f53ca44c6bc21a9cd92ad2e777b9cf1R298)
)).
This applies the same fix as
https://github.com/huggingface/text-generation-inference/pull/793 which
generates them on-the-fly using the appropriate value from the config
file
Fixes
https://github.com/huggingface/text-generation-inference/issues/1460
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [x] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
@OlivierDehaene OR @Narsil
This commit is contained in:
parent
2ae36a97fd
commit
13c62be467
|
@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_dim = int(config.rotary_pct * self.head_size)
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
|
@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.rotary_dim,
|
||||
base=config.rotary_emb_base,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
|
Loading…
Reference in New Issue