* Fix the bug * fix: run lints * fix: small syntax tweak --------- Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
This commit is contained in:
parent
21267f3ca3
commit
a379d5536b
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opt_sharded_handle(launcher):
|
||||
with launcher("facebook/opt-6.7b", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def opt_sharded(opt_sharded_handle):
|
||||
await opt_sharded_handle.health(300)
|
||||
return opt_sharded_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
async def test_opt(opt_sharded):
|
||||
pass
|
|
@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module):
|
|||
super().__init__()
|
||||
self.offset = 2
|
||||
self.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
|
||||
weights.get_tensor(
|
||||
f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
|
|||
super().__init__()
|
||||
self.process_group = weights.process_group
|
||||
self.hidden_size = config.hidden_size
|
||||
prefix = f"{prefix}.decoder.layers.{layer_id}"
|
||||
prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
|
||||
self.self_attn = OPTAttention(
|
||||
config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
|
@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||
self.max_target_positions = config.max_position_embeddings
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
prefix = prefix + "." if prefix else ""
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
|
||||
prefix=f"{prefix}decoder.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
|
||||
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_out = FastLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.decoder.project_out",
|
||||
prefix=f"{prefix}decoder.project_out",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_in = FastLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.decoder.project_in",
|
||||
prefix=f"{prefix}decoder.project_in",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
|
||||
prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
|||
def __init__(self, prefix, config, weights):
|
||||
super().__init__(config)
|
||||
|
||||
if not prefix:
|
||||
prefix = "model"
|
||||
else:
|
||||
prefix = f"{prefix}.model"
|
||||
|
||||
self.model = OPTModel(prefix, config, weights)
|
||||
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
|
||||
config,
|
||||
prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
|
Loading…
Reference in New Issue