Fix wrong mps selection below MasOS 12.3
This commit is contained in:
parent
7ba3923d5b
commit
76ab31e188
|
@ -3,8 +3,15 @@ import contextlib
|
|||
import torch
|
||||
from modules import errors
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def has_mps() -> bool:
|
||||
if getattr(torch, 'has_mps', False): return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
|
@ -25,7 +32,7 @@ def get_optimal_device():
|
|||
else:
|
||||
return torch.device("cuda")
|
||||
|
||||
if has_mps:
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
|
||||
return cpu
|
||||
|
|
Loading…
Reference in New Issue