fix rotary dim

Dean Wyatte 2024-01-30 03:53:08 +00:00
parent f2a00be169
commit 871e5e7338
1 changed file with 3 additions and 1 deletion

@@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.rotary_dim = int(config.rotary_pct * self.head_size)
+
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -100,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.head_size,
+            dim=self.rotary_dim,
             base=config.rotary_emb_base,
             device=weights.device,
         )
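
Context for the fix: GPT-NeoX applies rotary position embeddings to only a fraction of each attention head's dimensions, controlled by config.rotary_pct. Passing the full head_size as dim built the frequency table over the wrong span; the corrected code passes rotary_dim = int(rotary_pct * head_size). Below is a minimal, self-contained sketch of partial rotary application, not the repository's PositionRotaryEmbedding implementation; apply_partial_rotary and the rotate-half convention here are illustrative assumptions.

import torch


def apply_partial_rotary(x, cos, sin, rotary_dim):
    # Rotate only the first `rotary_dim` features of each head;
    # the remaining head_size - rotary_dim features pass through unchanged.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    x_rot = x_rot * cos + torch.cat((-x2, x1), dim=-1) * sin
    return torch.cat((x_rot, x_pass), dim=-1)


head_size, rotary_pct = 64, 0.25
rotary_dim = int(rotary_pct * head_size)  # 16 -- the value the fix passes as `dim`

# Build cos/sin tables over rotary_dim (not head_size) frequencies.
seq_len, num_heads, base = 8, 4, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, rotary_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (seq_len, rotary_dim)
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]       # broadcast over heads

q = torch.randn(seq_len, num_heads, head_size)
q = apply_partial_rotary(q, cos, sin, rotary_dim)
print(q.shape)  # torch.Size([8, 4, 64])

With rotary_pct = 0.25 and head_size = 64, only the first 16 dimensions of each head are rotated and the remaining 48 pass through untouched, which is exactly what the corrected dim=self.rotary_dim encodes.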