fix: prefer original layernorm names for 180B (#2365)
commit 133015f408
parent a64d407d64
@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
         prefix = f"{prefix}.h.{layer_id}"
 
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
+
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
 
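For orientation, a minimal sketch (not part of the commit) of how the layer-norm weight prefix ends up being resolved after this hunk. resolve_ln_prefix and the "transformer" base prefix are illustrative names; layer_id, the 80-layer check, and the ln_attn / input_layernorm names come from the diff above.

def resolve_ln_prefix(prefix: str, layer_id: int, num_hidden_layers: int) -> str:
    # Per-layer prefix, e.g. "transformer.h.0"
    layer_prefix = f"{prefix}.h.{layer_id}"
    # Falcon 180B (80 hidden layers) keeps the original ln_attn name;
    # other Falcon checkpoints use input_layernorm.
    ln_name = "ln_attn" if num_hidden_layers == 80 else "input_layernorm"
    return f"{layer_prefix}.{ln_name}"

# Falcon 180B, layer 0 -> "transformer.h.0.ln_attn"
print(resolve_ln_prefix("transformer", 0, 80))
# Smaller Falcon variants, layer 0 -> "transformer.h.0.input_layernorm"
print(resolve_ln_prefix("transformer", 0, 60))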
@@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
+
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
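Similarly, a hedged sketch (not part of the commit) of how the number of parallel-attention layer norms is chosen after this hunk. The config fields num_ln_in_parallel_attn and num_hidden_layers are the ones used in the diff above; SimpleNamespace only stands in for the real config object here.

from types import SimpleNamespace

def resolve_num_ln(config) -> int:
    # Default to a single layer norm when the config does not say otherwise.
    num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
    # Falcon 180B (80 hidden layers) has two layer norms in the parallel-attention block.
    if config.num_hidden_layers == 80:
        num_ln = 2
    return num_ln

# Falcon 180B-style config -> 2
print(resolve_num_ln(SimpleNamespace(num_hidden_layers=80)))
# Smaller config with an explicit value -> 1
print(resolve_num_ln(SimpleNamespace(num_hidden_layers=32, num_ln_in_parallel_attn=1)))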