* Fix the bug
* fix: run lints
* fix: small syntax tweak

---------

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
parent 21267f3ca3
commit a379d5536b
@@ -0,0 +1,19 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def opt_sharded_handle(launcher):
+    with launcher("facebook/opt-6.7b", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def opt_sharded(opt_sharded_handle):
+    await opt_sharded_handle.health(300)
+    return opt_sharded_handle.client
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+async def test_opt(opt_sharded):
+    pass
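For context, here is a minimal, self-contained sketch of the fixture pattern the new test uses: a synchronous module-level fixture that launches a server handle, and an async fixture that waits for health before handing out the client. The `DummyHandle` class, the `handle`/`client` fixture names, and the function scope are illustrative stand-ins; the real `launcher` fixture and health polling come from the repository's test harness, and pytest-asyncio is assumed to be installed.

# Sketch of the launcher/health/client fixture pattern above, using a
# stand-in handle so it runs without a model server. Requires pytest-asyncio.
import asyncio

import pytest
import pytest_asyncio


class DummyHandle:
    """Illustrative stand-in for the object yielded by `launcher(...)`."""

    client = "ready-client"

    async def health(self, timeout: int) -> None:
        # The real handle polls the launched server until it responds.
        await asyncio.sleep(0)


@pytest.fixture()
def handle():
    # The real fixture enters `launcher("facebook/opt-6.7b", num_shard=2)` here.
    yield DummyHandle()


@pytest_asyncio.fixture()
async def client(handle):
    await handle.health(300)
    return handle.client


@pytest.mark.asyncio
async def test_client_is_ready(client):
    assert client == "ready-client"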
@@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module):
         super().__init__()
         self.offset = 2
         self.weight = nn.Parameter(
-            weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
+            weights.get_tensor(
+                f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
+            )
         )

     def forward(
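The change above, and the related hunks below, all hinge on one detail: the `.` separator is only inserted when a prefix is actually set. A small standalone sketch of how the new f-string resolves; the helper name is hypothetical, the weight key is the one from the hunk.

# How the new conditional f-string resolves the weight key.
def embed_positions_key(prefix: str) -> str:
    return f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"


assert embed_positions_key("") == "decoder.embed_positions.weight"
assert embed_positions_key("model") == "model.decoder.embed_positions.weight"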
@@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.process_group = weights.process_group
         self.hidden_size = config.hidden_size
-        prefix = f"{prefix}.decoder.layers.{layer_id}"
+        prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
             prefix=f"{prefix}.self_attn",
@@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.vocab_size = config.vocab_size

+        prefix = prefix + "." if prefix else ""
+
         self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.decoder.embed_tokens", weights=weights
+            prefix=f"{prefix}decoder.embed_tokens", weights=weights
         )
         self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)

         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = FastLinear.load(
                 config,
-                prefix=f"{prefix}.decoder.project_out",
+                prefix=f"{prefix}decoder.project_out",
                 weights=weights,
                 bias=False,
             )
@@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel):
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_in = FastLinear.load(
                 config,
-                prefix=f"{prefix}.decoder.project_in",
+                prefix=f"{prefix}decoder.project_in",
                 weights=weights,
                 bias=False,
             )
@@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel):
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm.load(
-                prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
+                prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
             )
         else:
             self.final_layer_norm = None
@@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def __init__(self, prefix, config, weights):
         super().__init__(config)

-        if not prefix:
-            prefix = "model"
-        else:
-            prefix = f"{prefix}.model"
-
         self.model = OPTModel(prefix, config, weights)

         self.lm_head = SpeculativeHead.load(
-            config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
+            config,
+            prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
+            weights=weights,
         )

     def forward(
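A last hedged sketch, inferred from the diff alone: the deleted block rewrote an empty prefix to `model` before loading the head, while the new code leaves it empty, so the two versions resolve different weight keys. The helper names below are illustrative only.

# Old vs. new key resolution for the speculative head, mirroring the removed
# and added lines above.
def old_lm_head_key(prefix: str) -> str:
    prefix = "model" if not prefix else f"{prefix}.model"
    return f"{prefix}.decoder.embed_tokens"


def new_lm_head_key(prefix: str) -> str:
    return f"{prefix + '.' if prefix else ''}decoder.embed_tokens"


assert old_lm_head_key("") == "model.decoder.embed_tokens"
assert new_lm_head_key("") == "decoder.embed_tokens"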