fix(server): T5 weights names. (#582)

Fixes #541
This commit is contained in:
Nicolas Patry 2023-07-12 10:01:42 +02:00 committed by GitHub
parent f063ebde10
commit db4efbf4bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 7 deletions

View File

@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
super().__init__(config)
self.model_dim = config.d_model
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False

View File

@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
)
model = T5ForConditionalGeneration(config, weights)