fix rotary dim
This commit is contained in:
parent
f2a00be169
commit
871e5e7338
|
@@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.head_size = hidden_size // num_heads
|
self.head_size = hidden_size // num_heads
|
||||||
|
|
||||||
|
self.rotary_dim = int(config.rotary_pct * self.head_size)
|
||||||
|
|
||||||
if self.num_heads % weights.process_group.size() != 0:
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
@@ -100,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
config=config,
|
config=config,
|
||||||
dim=self.head_size,
|
dim=self.rotary_dim,
|
||||||
base=config.rotary_emb_base,
|
base=config.rotary_emb_base,
|
||||||
device=weights.device,
|
device=weights.device,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue