Better naming
This commit is contained in:
parent
f383af2729
commit
043d2edcf6
|
@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def manual_autocast():
|
def manual_cast():
|
||||||
for module_type in patch_module_list:
|
for module_type in patch_module_list:
|
||||||
org_forward = module_type.forward
|
org_forward = module_type.forward
|
||||||
module_type.forward = manual_cast_forward
|
module_type.forward = manual_cast_forward
|
||||||
|
@ -148,10 +148,10 @@ def autocast(disable=False):
|
||||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||||
|
|
||||||
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
|
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
|
||||||
return manual_autocast()
|
return manual_cast()
|
||||||
|
|
||||||
if has_mps() and shared.cmd_opts.precision != "full":
|
if has_mps() and shared.cmd_opts.precision != "full":
|
||||||
return manual_autocast()
|
return manual_cast()
|
||||||
|
|
||||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
Loading…
Reference in New Issue