fix: prefer original layernorm names for 180B (#2365)
parent a64d407d64
commit 133015f408
@@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module):
         prefix = f"{prefix}.h.{layer_id}"
 
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
+
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
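In plain terms, the hunk above stops hard-coding the input_layernorm weight name and instead picks the checkpoint prefix from the layer count, since Falcon 180B checkpoints store the attention layer norm under ln_attn. A minimal standalone sketch of that selection, assuming a hypothetical helper name (layernorm_prefix) and bare stand-in configs rather than the real model classes:

    from types import SimpleNamespace

    def layernorm_prefix(config, prefix: str) -> str:
        # Hypothetical helper mirroring the hunk above: Falcon 180B
        # (80 hidden layers) stores its attention layer norm weights under
        # "ln_attn"; smaller Falcon checkpoints use "input_layernorm".
        ln_prefix = "input_layernorm"
        if config.num_hidden_layers == 80:
            ln_prefix = "ln_attn"
        return f"{prefix}.{ln_prefix}"

    # Stand-in configs for illustration only.
    assert layernorm_prefix(SimpleNamespace(num_hidden_layers=80), "transformer.h.0") == "transformer.h.0.ln_attn"
    assert layernorm_prefix(SimpleNamespace(num_hidden_layers=32), "transformer.h.0") == "transformer.h.0.input_layernorm"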
@@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
+
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
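The second hunk applies the same heuristic in FlashRWLayerNorm: configs that omit num_ln_in_parallel_attn default to a single layer norm, but 180B-sized checkpoints are forced to two (ln_attn plus ln_mlp in the parallel-attention block). A minimal sketch of that resolution, again with a hypothetical function name and stand-in configs, not the repository's own API:

    from types import SimpleNamespace

    def resolve_num_ln(config) -> int:
        # In the case no number of layer norms is provided, default to 1.
        num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
        # Falcon 180B (80 hidden layers) has 2 layer norms in its
        # parallel-attention block, even when the config omits the field.
        if config.num_hidden_layers == 80:
            num_ln = 2
        return num_ln

    print(resolve_num_ln(SimpleNamespace(num_hidden_layers=80)))  # 2
    print(resolve_num_ln(SimpleNamespace(num_hidden_layers=32)))  # 1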