use static rotary embedding
parent 069895b985
commit f2a00be169
@@ -98,8 +98,11 @@ class FlashNeoxAttention(torch.nn.Module):
         )
 
         self.num_heads = self.num_heads // weights.process_group.size()
 
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.head_size,
+            base=config.rotary_emb_base,
+            device=weights.device,
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
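For context: the removed PositionRotaryEmbedding.load call presumably read a precomputed inv_freq buffer from the checkpoint under f"{prefix}.rotary_emb", whereas a static constructor can derive the same frequencies from the head dimension and rotary base alone, so no checkpoint weights are needed. Below is a minimal sketch of that idea, assuming the standard RoPE frequency formula; the class and method names are illustrative, not the library's actual implementation.

import torch

class StaticRotaryEmbedding:
    # Illustrative sketch: rotary embedding whose inverse frequencies are
    # computed from (dim, base) instead of loaded from checkpoint weights.
    def __init__(self, dim: int, base: float, device: torch.device):
        # Standard RoPE: inv_freq[i] = base^(-2i/dim) for i in [0, dim/2).
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        self.inv_freq = 1.0 / (base ** exponents)

    def cos_sin(self, positions: torch.Tensor):
        # Rotation angles are the outer product of token positions and the
        # inverse frequencies; callers apply the resulting cos/sin pairs to
        # the query/key halves.
        freqs = torch.outer(positions.float(), self.inv_freq)
        return torch.cos(freqs), torch.sin(freqs)

Because nothing is read from the checkpoint, the constructor only needs self.head_size and config.rotary_emb_base, which is exactly what the new call site passes.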