parent
f063ebde10
commit
db4efbf4bc
|
@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model_dim = config.d_model
|
self.model_dim = config.d_model
|
||||||
|
|
||||||
try:
|
|
||||||
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
||||||
except RuntimeError:
|
|
||||||
self.shared = TensorParallelEmbedding(
|
|
||||||
prefix="encoder.embed_tokens", weights=weights
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_config = copy.deepcopy(config)
|
encoder_config = copy.deepcopy(config)
|
||||||
encoder_config.is_decoder = False
|
encoder_config.is_decoder = False
|
||||||
|
|
|
@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(
|
weights = Weights(
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
filenames,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
process_group=self.process_group,
|
||||||
|
aliases={
|
||||||
|
"shared.weight": [
|
||||||
|
"encoder.embed_tokens.weight",
|
||||||
|
"decoder.embed_tokens.weight",
|
||||||
|
]
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
model = T5ForConditionalGeneration(config, weights)
|
model = T5ForConditionalGeneration(config, weights)
|
||||||
|
|
Loading…
Reference in New Issue