diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 708641e7..0691da9b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
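
For context, a minimal sketch of why the getattr fallback matters: older Falcon configs omit num_ln_in_parallel_attn entirely, so reading the attribute directly raises AttributeError, while getattr with a default of 1 keeps them working and lets Falcon2 configs override it. The FakeFalconConfig and FakeFalcon2Config classes below are hypothetical stand-ins for illustration, not part of the patch or of transformers.

# Minimal sketch (assumed names, not from the patch): shows the fallback
# behaviour of getattr with a default on two stand-in config classes.
from dataclasses import dataclass


@dataclass
class FakeFalconConfig:
    # Stand-in for an older Falcon config that has no num_ln_in_parallel_attn.
    hidden_size: int = 4544


@dataclass
class FakeFalcon2Config:
    # Stand-in for a Falcon2-style config that carries the attribute.
    hidden_size: int = 4544
    num_ln_in_parallel_attn: int = 2


# Attribute missing -> the default of 1 is used instead of raising AttributeError.
assert getattr(FakeFalconConfig(), "num_ln_in_parallel_attn", 1) == 1
# Attribute present -> the configured value wins.
assert getattr(FakeFalcon2Config(), "num_ln_in_parallel_attn", 1) == 2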