fix(server): t5 cannot run in f16 (#356)

Fix #349
This commit is contained in:
OlivierDehaene 2023-05-23 12:15:54 +02:00 committed by GitHub
parent 91d9beec90
commit 4f4c9c1665
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 1 deletion

View File

@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
 self.process_group, rank, world_size = initialize_torch_distributed()
 if torch.cuda.is_available():
     device = torch.device(f"cuda:{rank}")
-    dtype = torch.float16
+    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
 else:
     device = torch.device("cpu")
     dtype = torch.float32