fix(server): Fixing T5 in case the names are mixed up. (#475)

This commit is contained in:
Nicolas Patry 2023-06-20 18:03:36 +02:00 committed by GitHub
parent 53aa9194c8
commit c9c65ab323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 1 deletions

View File

@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
super().__init__(config)
self.model_dim = config.d_model
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False