feat(server): use float16 (#304)

OlivierDehaene 2023-05-10 15:51:10 +02:00 committed by GitHub
parent 68e9d6ab33
commit 745f596c88
8 changed files with 8 additions and 8 deletions


@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32


@@ -452,7 +452,7 @@ class CausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")


@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32


@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
         self.master = rank == 0
         if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32


@@ -54,7 +54,7 @@ class OPTSharded(OPT):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32


@@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")


@@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")


@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
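
Every hunk makes the same one-line change: on CUDA the dtype is now unconditionally torch.float16, where it was previously torch.bfloat16 when the GPU supported it and torch.float32 otherwise; the CPU path keeps torch.float32. A minimal standalone sketch of the selection logic shared by all eight model classes after this commit (the helper name select_device_and_dtype is illustrative, not part of the repo):

import torch

# Illustrative helper, not part of the repository: mirrors the
# device/dtype selection used by the model classes after this change.
def select_device_and_dtype(rank: int = 0):
    if torch.cuda.is_available():
        # GPU path: unconditionally float16 (previously bfloat16 when
        # supported, float32 otherwise).
        return torch.device(f"cuda:{rank}"), torch.float16
    # CPU path is unchanged: full float32 precision.
    return torch.device("cpu"), torch.float32

device, dtype = select_device_and_dtype()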