From 871e5e73389091cd5d38507c64d4a8609cf5e5e3 Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Tue, 30 Jan 2024 03:53:08 +0000 Subject: [PATCH] fix rotary dim --- .../models/custom_modeling/flash_neox_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index c58ba5fe..3ee344e4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + self.rotary_dim = int(config.rotary_pct * self.head_size) + if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " @@ -100,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb = PositionRotaryEmbedding.static( config=config, - dim=self.head_size, + dim=self.rotary_dim, base=config.rotary_emb_base, device=weights.device, )