From a379d5536bb2de55154dc09c3a1f24ce58cb7df5 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 7 Aug 2024 23:14:02 -0400
Subject: [PATCH] Fix the prefix for OPT model in opt_modelling.py #2370 (CI
 RUN) (#2371)

* Fix the bug

* fix: run lints

* fix: small syntax tweak

---------

Co-authored-by: Sadra Barikbin
---
 integration-tests/models/test_opt.py                | 19 ++++++++++++++
 .../models/custom_modeling/opt_modeling.py          | 25 ++++++++++---------
 2 files changed, 32 insertions(+), 12 deletions(-)
 create mode 100644 integration-tests/models/test_opt.py

diff --git a/integration-tests/models/test_opt.py b/integration-tests/models/test_opt.py
new file mode 100644
index 00000000..cbeb6376
--- /dev/null
+++ b/integration-tests/models/test_opt.py
@@ -0,0 +1,19 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def opt_sharded_handle(launcher):
+    with launcher("facebook/opt-6.7b", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def opt_sharded(opt_sharded_handle):
+    await opt_sharded_handle.health(300)
+    return opt_sharded_handle.client
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+async def test_opt(opt_sharded):
+    pass
diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py
index 84a1c069..bd440321 100644
--- a/server/text_generation_server/models/custom_modeling/opt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py
@@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module):
         super().__init__()
         self.offset = 2
         self.weight = nn.Parameter(
-            weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
+            weights.get_tensor(
+                f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
+            )
         )
 
     def forward(
@@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.process_group = weights.process_group
         self.hidden_size = config.hidden_size
-        prefix = f"{prefix}.decoder.layers.{layer_id}"
+        prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
             prefix=f"{prefix}.self_attn",
@@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.vocab_size = config.vocab_size
 
+        prefix = prefix + "." if prefix else ""
+
         self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.decoder.embed_tokens", weights=weights
+            prefix=f"{prefix}decoder.embed_tokens", weights=weights
         )
         self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
 
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = FastLinear.load(
                 config,
-                prefix=f"{prefix}.decoder.project_out",
+                prefix=f"{prefix}decoder.project_out",
                 weights=weights,
                 bias=False,
             )
@@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel):
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_in = FastLinear.load(
                 config,
-                prefix=f"{prefix}.decoder.project_in",
+                prefix=f"{prefix}decoder.project_in",
                 weights=weights,
                 bias=False,
             )
@@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel):
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm.load(
-                prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
+                prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
             )
         else:
             self.final_layer_norm = None
@@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def __init__(self, prefix, config, weights):
         super().__init__(config)
 
-        if not prefix:
-            prefix = "model"
-        else:
-            prefix = f"{prefix}.model"
-
         self.model = OPTModel(prefix, config, weights)
 
         self.lm_head = SpeculativeHead.load(
-            config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
+            config,
+            prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
+            weights=weights,
         )
 
     def forward(
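
Note (not part of the patch itself): the change replaces unconditional f"{prefix}.decoder..." weight names with a prefix that is dot-joined only when non-empty, so that an empty prefix resolves to "decoder.*" rather than "model.decoder.*". A minimal sketch of that behavior, using a hypothetical weight_name helper that does not exist in the repository:

    # Hypothetical helper illustrating the prefix handling this patch introduces.
    def weight_name(prefix: str, suffix: str) -> str:
        # Empty prefix -> "decoder.embed_tokens"; "model" -> "model.decoder.embed_tokens"
        return f"{prefix + '.' if prefix else ''}{suffix}"

    assert weight_name("", "decoder.embed_tokens") == "decoder.embed_tokens"
    assert weight_name("model", "decoder.embed_tokens") == "model.decoder.embed_tokens"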