fix: default num_ln_in_parallel_attn to one if not supplied (#2364)
parent 1768c00b9f
commit a64d407d64
@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)

         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
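Why the getattr matters: configs that predate Falcon2 omit the attribute entirely, so the old direct attribute access could raise AttributeError at model load. A minimal sketch of the difference, using hypothetical SimpleNamespace stand-ins rather than TGI's real config classes:

from types import SimpleNamespace

# Hypothetical configs for illustration only; not TGI's actual config objects.
falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)  # Falcon2 sets the field
legacy_config = SimpleNamespace()                            # older configs omit it

# Old behavior: direct attribute access fails when the field is missing.
try:
    legacy_config.num_ln_in_parallel_attn
except AttributeError:
    print("old code: AttributeError on a legacy config")

# New behavior: getattr falls back to 1 when the field is absent.
print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # -> 2
print(getattr(legacy_config, "num_ln_in_parallel_attn", 1))   # -> 1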