fix: prefer original layernorm names for 180B (#2365)

drbh 2024-08-06 15:25:30 -04:00 committed by GitHub
parent a64d407d64
commit 133015f408
1 changed file with 10 additions and 1 deletion


@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
         prefix = f"{prefix}.h.{layer_id}"
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
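
For context, the hunk above reduces to a simple name selection: Falcon 180B checkpoints (80 hidden layers) store the attention layer norm under "ln_attn", while smaller Falcon checkpoints use "input_layernorm". A minimal sketch of that selection is shown below; the standalone helper name is hypothetical and not part of the patched code.

# Hypothetical helper mirroring the diff above: pick the layer-norm weight
# name based on the number of hidden layers, as the patch does inline.
def attn_ln_prefix(prefix: str, num_hidden_layers: int) -> str:
    ln_name = "ln_attn" if num_hidden_layers == 80 else "input_layernorm"
    return f"{prefix}.{ln_name}"

# Example: attn_ln_prefix("transformer.h.0", 80) -> "transformer.h.0.ln_attn"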
@@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",