From db4efbf4bcb851a185879099ac01fdc61e34a062 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 12 Jul 2023 10:01:42 +0200
Subject: [PATCH] fix(server): T5 weights names. (#582)

Fixes #541
---
 .../models/custom_modeling/t5_modeling.py  |  7 +------
 server/text_generation_server/models/t5.py | 11 ++++++++++-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index 1ea8280..5779a27 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         super().__init__(config)
         self.model_dim = config.d_model
 
-        try:
-            self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
-        except RuntimeError:
-            self.shared = TensorParallelEmbedding(
-                prefix="encoder.embed_tokens", weights=weights
-            )
+        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
 
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 1b7073a..133aafd 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={
+                "shared.weight": [
+                    "encoder.embed_tokens.weight",
+                    "decoder.embed_tokens.weight",
+                ]
+            },
         )
 
         model = T5ForConditionalGeneration(config, weights)