fix rotary dim
This commit is contained in:
parent
f2a00be169
commit
871e5e7338
|
@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_dim = int(config.rotary_pct * self.head_size)
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
|
@ -100,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
dim=self.rotary_dim,
|
||||
base=config.rotary_emb_base,
|
||||
device=weights.device,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue