diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 25ec8cb8..7c50644a 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM): self.master = rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 610dc4e2..f838fc5c 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -452,7 +452,7 @@ class CausalLM(Model): ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 57df0bab..f6098e6c 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -199,7 +199,7 @@ class GalacticaSharded(Galactica): self.master = rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index c71bf366..cd36dba0 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM): self.master = rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index cdc32c56..44f15df3 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -54,7 +54,7 @@ class OPTSharded(OPT): self.master = rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index a7b09a82..5a142676 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -17,7 +17,7 @@ class SantaCoder(CausalLM): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index d4a0ddcc..77912cff 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -506,7 +506,7 @@ class Seq2SeqLM(Model): ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 5691c005..4cfdea9e 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM): self.master = rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + dtype = torch.float16 else: device = torch.device("cpu") dtype = torch.float32