fix: default num_ln_in_parallel_attn to one if not supplied (#2364)

Author: drbh (committed by GitHub)
Date: 2024-08-06 13:33:22 -04:00
Parent: 1768c00b9f
Commit: a64d407d64
1 changed file with 3 additions and 1 deletion


@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)

         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
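
For context, a minimal sketch of what the getattr fallback in this change does. The SimpleNamespace configs below are illustrative stand-ins, not the real Falcon config classes:

    from types import SimpleNamespace

    # Falcon2-style config that spells out how many layer norms to use
    falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)

    # older config that does not define the attribute at all
    legacy_config = SimpleNamespace()

    # getattr returns the attribute when present, otherwise the default of 1
    print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # -> 2
    print(getattr(legacy_config, "num_ln_in_parallel_attn", 1))   # -> 1

With the fallback in place, configs that predate the attribute load with a single layer norm instead of raising an AttributeError.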