From f2a00be169b11b569f99d47b4aad795ffb2e7f6d Mon Sep 17 00:00:00 2001
From: Dean Wyatte
Date: Fri, 26 Jan 2024 22:22:55 +0000
Subject: [PATCH] use static rotary embedding

---
 .../models/custom_modeling/flash_neox_modeling.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index eea5f787..c58ba5fe 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -98,8 +98,11 @@ class FlashNeoxAttention(torch.nn.Module):
         )
         self.num_heads = self.num_heads // weights.process_group.size()
 
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.head_size,
+            base=config.rotary_emb_base,
+            device=weights.device,
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
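
Note for reviewers: a minimal sketch of what the switch from
PositionRotaryEmbedding.load to PositionRotaryEmbedding.static amounts to,
assuming the standard RoPE inverse-frequency formulation (this is an
illustration, not the actual TGI implementation). As I understand it, .load
reads a serialized inv_freq buffer out of the checkpoint under
"{prefix}.rotary_emb", whereas .static derives the same frequencies from the
head dimension and config.rotary_emb_base, so the checkpoint does not need to
contain any rotary weights at all. The helper name static_inv_freq below is
hypothetical.

    import torch

    def static_inv_freq(dim: int, base: float, device: torch.device) -> torch.Tensor:
        # Standard RoPE inverse frequencies: inv_freq[i] = base^(-2i/dim)
        # for i in [0, dim/2). A "static" rotary embedding computes this
        # from config values instead of loading it from checkpoint weights.
        return 1.0 / (
            base
            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )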