Fix the prefix for OPT model in opt_modelling.py #2370 (CI RUN) (#2371)

* Fix the bug

* fix: run lints

* fix: small syntax tweak

---------

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
This commit is contained in:
drbh 2024-08-07 23:14:02 -04:00 committed by GitHub
parent 21267f3ca3
commit a379d5536b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 12 deletions

View File

@ -0,0 +1,19 @@
import pytest
@pytest.fixture(scope="module")
def opt_sharded_handle(launcher):
with launcher("facebook/opt-6.7b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def opt_sharded(opt_sharded_handle):
await opt_sharded_handle.health(300)
return opt_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_opt(opt_sharded):
pass

View File

@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module):
super().__init__() super().__init__()
self.offset = 2 self.offset = 2
self.weight = nn.Parameter( self.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") weights.get_tensor(
f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
)
) )
def forward( def forward(
@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"{prefix}.decoder.layers.{layer_id}" prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel):
self.max_target_positions = config.max_position_embeddings self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
prefix = prefix + "." if prefix else ""
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.decoder.embed_tokens", weights=weights prefix=f"{prefix}decoder.embed_tokens", weights=weights
) )
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load( self.project_out = FastLinear.load(
config, config,
prefix=f"{prefix}.decoder.project_out", prefix=f"{prefix}decoder.project_out",
weights=weights, weights=weights,
bias=False, bias=False,
) )
@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel):
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load( self.project_in = FastLinear.load(
config, config,
prefix=f"{prefix}.decoder.project_in", prefix=f"{prefix}decoder.project_in",
weights=weights, weights=weights,
bias=False, bias=False,
) )
@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164 # see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm: if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load( self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
) )
else: else:
self.final_layer_norm = None self.final_layer_norm = None
@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__(config) super().__init__(config)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights) self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights config,
prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
weights=weights,
) )
def forward( def forward(