parent
f063ebde10
commit
db4efbf4bc
|
@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||
super().__init__(config)
|
||||
self.model_dim = config.d_model
|
||||
|
||||
try:
|
||||
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
||||
except RuntimeError:
|
||||
self.shared = TensorParallelEmbedding(
|
||||
prefix="encoder.embed_tokens", weights=weights
|
||||
)
|
||||
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
||||
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.is_decoder = False
|
||||
|
|
|
@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
|
|||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
aliases={
|
||||
"shared.weight": [
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
model = T5ForConditionalGeneration(config, weights)
|
||||
|
|
Loading…
Reference in New Issue