diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index eea5f787..c58ba5fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -98,8 +98,11 @@ class FlashNeoxAttention(torch.nn.Module): ) self.num_heads = self.num_heads // weights.process_group.size() - self.rotary_emb = PositionRotaryEmbedding.load( - config=config, prefix=f"{prefix}.rotary_emb", weights=weights + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rotary_emb_base, + device=weights.device, ) self.softmax_scale = self.head_size ** (-0.5)