use static rotary embedding

This commit is contained in:
Dean Wyatte 2024-01-26 22:22:55 +00:00
parent 069895b985
commit f2a00be169
1 changed file with 5 additions and 2 deletions

View File

@@ -98,8 +98,11 @@ class FlashNeoxAttention(torch.nn.Module):
)
self.num_heads = self.num_heads // weights.process_group.size()
-self.rotary_emb = PositionRotaryEmbedding.load(
-config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+self.rotary_emb = PositionRotaryEmbedding.static(
+config=config,
+dim=self.head_size,
+base=config.rotary_emb_base,
+device=weights.device,
)
self.softmax_scale = self.head_size ** (-0.5)