diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index c58ba5fe..3ee344e4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + self.rotary_dim = int(config.rotary_pct * self.head_size) + if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -100,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb = PositionRotaryEmbedding.static( config=config, - dim=self.head_size, + dim=self.rotary_dim, base=config.rotary_emb_base, device=weights.device, )