Add MPS manual cast

This commit is contained in:
KohakuBlueleaf 2023-10-28 16:52:35 +08:00
parent d4d3134f6d
commit ddc2a3499b
1 changed files with 5 additions and 1 deletions

View File

@ -121,6 +121,8 @@ def manual_autocast():
def manual_cast_forward(self, *args, **kwargs): def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype org_dtype = next(self.parameters()).dtype
self.to(dtype) self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
result = self.org_forward(*args, **kwargs) result = self.org_forward(*args, **kwargs)
self.to(org_dtype) self.to(org_dtype)
return result return result
@ -136,7 +138,6 @@ def manual_autocast():
def autocast(disable=False): def autocast(disable=False):
print(fp8, dtype, shared.cmd_opts.precision, device)
if disable: if disable:
return contextlib.nullcontext() return contextlib.nullcontext()
@ -146,6 +147,9 @@ def autocast(disable=False):
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_autocast() return manual_autocast()
if has_mps() and shared.cmd_opts.precision != "full":
return manual_autocast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full": if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext() return contextlib.nullcontext()