From 16b4509fa60ec03102b2452b41799dafccd35970 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 17 Dec 2022 03:21:19 -0500 Subject: [PATCH 1/2] Add numpy fix for MPS on PyTorch 1.12.1 When saving training results with torch.save(), an exception is thrown: "RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead." So for MPS, check if Tensor.requires_grad and detach() if necessary. --- modules/devices.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index f8cffae1a..800510b7c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -125,7 +125,16 @@ def layer_norm_fix(*args, **kwargs): return orig_layer_norm(*args, **kwargs) +# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 +orig_tensor_numpy = torch.Tensor.numpy +def numpy_fix(self, *args, **kwargs): + if self.requires_grad: + self = self.detach() + return orig_tensor_numpy(self, *args, **kwargs) + + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): torch.Tensor.to = tensor_to_fix torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix From cca16373def60bfc6d159a3c2dca91d0ba48112a Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 17 Dec 2022 03:24:54 -0500 Subject: [PATCH 2/2] Add attributes used by MPS --- modules/safe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/safe.py b/modules/safe.py index 10460ad00..7c89c4c22 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler): if module == 'collections' and name == 'OrderedDict': return getattr(collections, name) - if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: + if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']: + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: return getattr(torch, name) if module == 'torch.nn.modules.container' and name in ['ParameterDict']: return getattr(torch.nn.modules.container, name) - if module == 'numpy.core.multiarray' and name == 'scalar': - return numpy.core.multiarray.scalar - if module == 'numpy' and name == 'dtype': - return numpy.dtype + if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: + return getattr(numpy.core.multiarray, name) + if module == 'numpy' and name in ['dtype', 'ndarray']: + return getattr(numpy, name) if module == '_codecs' and name == 'encode': return encode if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':