fix: default num_ln_in_parallel_attn to one if not supplied (#2364)
parent 1768c00b9f
commit a64d407d64
@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
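A minimal sketch of why the patched line fixes the crash: older Falcon configs may not carry num_ln_in_parallel_attn at all, so attribute access raises, while getattr with a default of 1 degrades gracefully. The SimpleNamespace configs below are hypothetical stand-ins for illustration, not the library's config classes.

    from types import SimpleNamespace

    # Hypothetical configs: Falcon2 ships the field, a legacy config omits it.
    falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)
    legacy_config = SimpleNamespace()  # no num_ln_in_parallel_attn attribute

    # Old code: raises AttributeError for the legacy config.
    # num_ln = legacy_config.num_ln_in_parallel_attn

    # Patched code: falls back to a single layer norm when unspecified.
    print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # -> 2
    print(getattr(legacy_config, "num_ln_in_parallel_attn", 1))   # -> 1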