From 4f4c9c16655841ea3ccce69ff46b9778a097cb70 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 23 May 2023 12:15:54 +0200
Subject: [PATCH] fix(server): t5 cannot run in f16 (#356)

Fix #349
---
 server/text_generation_server/models/t5.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index b1ba2432..2fd67574 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
             dtype = torch.float32
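
Note (not part of the patch): a minimal standalone sketch of the same dtype-selection policy, assuming only that PyTorch is installed; the helper name select_t5_dtype is ours for illustration. T5 checkpoints are trained in bfloat16, and casting them to float16 tends to overflow the fp16 range (the failure reported in #349), which is why the patch prefers bf16 when the GPU supports it and falls back to fp32 otherwise.

    import torch

    def select_t5_dtype() -> torch.dtype:
        # Same policy as the patched line: prefer bfloat16 on GPUs that
        # support it, otherwise keep full float32 precision. float16 is
        # avoided because T5 activations can overflow its numeric range.
        if torch.cuda.is_available():
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        return torch.float32

    if __name__ == "__main__":
        print(select_t5_dtype())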