parent
91d9beec90
commit
4f4c9c1665
|
@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
|
|||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
|
Loading…
Reference in New Issue