fix(server): Fix T5 loading when the shared embedding weights are stored under "encoder.embed_tokens" instead of "shared". (#475)
This commit is contained in:
parent
53aa9194c8
commit
c9c65ab323
|
@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||
super().__init__(config)
|
||||
self.model_dim = config.d_model
|
||||
|
||||
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
||||
try:
|
||||
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
||||
except RuntimeError:
|
||||
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.is_decoder = False
|
||||
|
|
Loading…
Reference in New Issue