Fix MPS cache cleanup

Importing torch does not import torch.mps so the call failed.
This commit is contained in:
Aarni Koskela 2023-07-10 21:18:34 +03:00
parent 7b833291b3
commit b85fc7187d
2 changed files with 17 additions and 2 deletions

View File

@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif has_mps() and hasattr(torch.mps, 'empty_cache'):
torch.mps.empty_cache()
if has_mps():
mac_specific.torch_mps_gc()
def enable_tf32():

View File

@ -1,8 +1,12 @@
import logging
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
log = logging.getLogger()
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@ -19,9 +23,19 @@ def check_for_mps() -> bool:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps()
def torch_mps_gc() -> None:
try:
from torch.mps import empty_cache
empty_cache()
except Exception:
log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':