Fixing baichuan override. (#2158)
parent d0225b1015
commit 4f55f15840
@@ -117,6 +117,11 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
+        # Setting defaults for baichuan custom config which doesn't apply them.
+        config.rope_theta = getattr(config, "rope_theta", 10000)
+        config.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
+        )
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
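For context, a minimal sketch of the getattr-based defaulting this patch applies: a Baichuan-style custom config may not define rope_theta or num_key_value_heads, so the attention module falls back to 10000 and num_attention_heads respectively. The SimpleNamespace config below is a hypothetical stand-in for illustration only, not the real Baichuan config class.

    from types import SimpleNamespace

    # Hypothetical stand-in for a Baichuan-style config that omits both fields.
    config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

    # Same defaulting pattern as the patch above: keep the attribute if the
    # config already sets it, otherwise fall back to a sensible default.
    config.rope_theta = getattr(config, "rope_theta", 10000)
    config.num_key_value_heads = getattr(
        config, "num_key_value_heads", config.num_attention_heads
    )

    print(config.rope_theta)           # 10000
    print(config.num_key_value_heads)  # 32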