fix facebook/opt-125m not working issue (#2824)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2024-12-12 21:41:30 +08:00 committed by GitHub
parent c3bd7212c2
commit bf59118a93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -99,7 +99,7 @@ class OPTLearnedPositionalEmbedding(nn.Module):
self.offset = 2 self.offset = 2
self.weight = nn.Parameter( self.weight = nn.Parameter(
weights.get_tensor( weights.get_tensor(
f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" f"{prefix if prefix else ''}decoder.embed_positions.weight"
) )
) )
@ -317,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -755,6 +755,8 @@ class OPTModel(OPTPreTrainedModel):
class OPTForCausalLM(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__(config) super().__init__(config)
if not prefix and any(s.startswith("model") for s in weights.routing.keys()):
prefix = "model"
self.model = OPTModel(prefix, config, weights) self.model = OPTModel(prefix, config, weights)