From a64d407d64de1ba168523df5391863b1f85c0824 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 6 Aug 2024 13:33:22 -0400
Subject: [PATCH] fix: default num_ln_in_parallel_attn to one if not supplied
 (#2364)

---
 .../models/custom_modeling/flash_rw_modeling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 708641e7..0691da9b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module):
 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix: str, weights):
         super().__init__()
-        self.num_ln = config.num_ln_in_parallel_attn
+        # Falcon2 includes the number of layer norms in the config
+        # in the case no number of layer norms is provided, we default to 1
+        self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
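
Note: a minimal sketch of the getattr fallback this patch introduces. It is not part of the
patch itself; SimpleNamespace stands in for the real Falcon config object purely for
illustration.

    # Illustrative only: shows why getattr(config, "num_ln_in_parallel_attn", 1)
    # keeps older configs working while honoring Falcon2-style configs.
    from types import SimpleNamespace

    # Falcon2-style config that explicitly sets the number of layer norms.
    falcon2_config = SimpleNamespace(num_ln_in_parallel_attn=2)
    # Older config that omits the field entirely.
    legacy_config = SimpleNamespace()

    print(getattr(falcon2_config, "num_ln_in_parallel_attn", 1))  # -> 2
    print(getattr(legacy_config, "num_ln_in_parallel_attn", 1))   # -> 1 (the new default)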