use static rotary embedding
commit f2a00be169 (parent 069895b985)
@@ -98,8 +98,11 @@ class FlashNeoxAttention(torch.nn.Module):
         )
 
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.head_size,
+            base=config.rotary_emb_base,
+            device=weights.device,
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
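The change swaps the checkpoint-backed PositionRotaryEmbedding.load for a static constructor driven only by dim, base, and device. A minimal sketch of what such a static construction typically computes, assuming the standard RoPE inverse-frequency schedule (the helper name below is hypothetical, not the library's API):

import torch

# Hypothetical sketch of a static rotary setup: derive the RoPE inverse
# frequencies from dim/base instead of loading a tensor from checkpoint
# weights, mirroring the parameters passed in the diff above.
def static_rotary_inv_freq(dim: int, base: float, device: torch.device) -> torch.Tensor:
    # One frequency per pair of hidden dimensions: base^(-2i/dim).
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    return 1.0 / (base ** exponents)

# Example with illustrative values only.
inv_freq = static_rotary_inv_freq(dim=96, base=10000.0, device=torch.device("cpu"))
print(inv_freq.shape)  # torch.Size([48])

Computing these values on the fly keeps the rotary buffers out of the sharded checkpoint, which is presumably why the loaded variant was dropped.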