feat(server): use float16 (#304)
commit 745f596c88
parent 68e9d6ab33
@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
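For context, a minimal, self-contained sketch of the dtype selection before and after this change (the helper names here are illustrative, not the repository's own):

import torch

def pick_dtype_old(device: torch.device) -> torch.dtype:
    # Before: prefer bfloat16 on GPUs that support it, else fall back to float32.
    if device.type == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
    return torch.float32

def pick_dtype_new(device: torch.device) -> torch.dtype:
    # After: always use float16 on CUDA; CPU keeps float32.
    return torch.float16 if device.type == "cuda" else torch.float32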
@@ -452,7 +452,7 @@ class CausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
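The selected dtype typically flows straight into the model load. A minimal sketch using the public transformers API (the model id below is a placeholder chosen for illustration):

import torch
from transformers import AutoModelForCausalLM

model_id = "bigscience/bloom-560m"  # placeholder model id, for illustration only

if torch.cuda.is_available():
    device, dtype = torch.device("cuda"), torch.float16
else:
    device, dtype = torch.device("cpu"), torch.float32

# torch_dtype controls the precision the checkpoint weights are loaded in.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)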
@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -54,7 +54,7 @@ class OPTSharded(OPT):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
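As background on the two half-precision formats this diff chooses between (the commit itself gives no rationale): float16 carries more mantissa precision, while bfloat16 keeps float32's exponent range. A small script makes the trade-off concrete:

import torch

# float16: finer precision (smaller eps) but tiny dynamic range (max ~65504).
# bfloat16: float32-sized exponent (max ~3.4e38) but coarser precision.
for dt in (torch.float16, torch.bfloat16):
    info = torch.finfo(dt)
    print(f"{dt}: max={info.max:.3e}, eps={info.eps:.3e}")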