Add FP32 fallback support to torch.nn.functional.interpolate
This makes `torch.nn.functional.interpolate` fall back to FP32 execution when the FP16 call fails. Background: on some environments, such as Mx-chip macOS devices, the following error occurs:

```
"torch/nn/functional.py", line 3931, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

In this case `--no-half` does not help, so this commit adds an FP32 fallback execution path. Note that `upsample_nearest2d` is called from `torch.nn.functional.interpolate`, and the fallback for `torch.nn.functional.interpolate` is needed in `VAEApprox.forward` in `modules/sd_vae_approx.py` and in `Upsample.forward` in `repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py`.
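For context, a minimal sketch of the failure mode and of the fallback idea (hypothetical repro; the device guard, tensor shape, and interpolation mode are illustrative, not taken from the commit):

```python
import torch
import torch.nn.functional as F

# Pick MPS when available; the missing-Half-kernel error is specific to MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(1, 3, 64, 64, dtype=torch.float16, device=device)

try:
    y = F.interpolate(x, scale_factor=2, mode="nearest")
except RuntimeError as e:
    # The idea of this commit: retry the same call in FP32, then cast back.
    if "not implemented for" in str(e) and "Half" in str(e):
        y = F.interpolate(x.to(torch.float32), scale_factor=2, mode="nearest").to(x.dtype)
    else:
        raise
```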
parent 39eae9f009
commit a0096c5897
modules/mac_specific.py

```diff
@@ -1,6 +1,8 @@
 import logging
 
 import torch
+from typing import Optional, List
+from torch import Tensor
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
```
```diff
@@ -51,6 +53,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
     return cumsum_func(input, *args, **kwargs)
 
 
+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+    try:
+        return orig_func(*args, **kwargs)
+    except RuntimeError as e:
+        if "not implemented for" in str(e) and "Half" in str(e):
+            input_tensor = args[0]
+            return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+        else:
+            print(f"An unexpected RuntimeError occurred: {str(e)}")
+
 if has_mps:
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
```
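As a usage note, the new helper can be exercised standalone roughly like this (hypothetical direct call, assuming the function above is importable from `modules.mac_specific`; in the webui it is only invoked through the `CondFunc` hook added in the next hunk). One caveat visible in the code: a `RuntimeError` that does not match the Half-kernel pattern is printed and swallowed, so the wrapper then implicitly returns `None` instead of re-raising.

```python
import torch
import torch.nn.functional as F

from modules.mac_specific import interpolate_with_fp32_fallback  # assumed import path

x = torch.randn(1, 3, 8, 8, dtype=torch.float16)

# Tries the FP16 call first; on a "... not implemented for 'Half'" RuntimeError,
# retries in FP32 and casts the result back to the input's dtype.
y = interpolate_with_fp32_fallback(F.interpolate, x, scale_factor=2, mode="nearest")
assert y.dtype == torch.float16
```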
```diff
@@ -77,6 +90,9 @@ if has_mps:
     # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
     CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
 
+    # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+    CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
     # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
     if platform.processor() == 'i386':
         for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
```
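For readers unfamiliar with `CondFunc` from `modules/sd_hijack_utils`: conceptually it monkey-patches the function at the given dotted path, routing calls through the substitute (which receives the original function as its first argument) whenever the condition holds; passing `None` as the condition, as above, applies the substitute unconditionally. A simplified sketch of that mechanism (my paraphrase, not the webui's actual implementation):

```python
import importlib

def cond_func_sketch(orig_path: str, sub_func, cond_func):
    """Hypothetical simplified stand-in for modules.sd_hijack_utils.CondFunc."""
    module_path, attr = orig_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        # cond_func=None (as in this commit) means "always substitute".
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)

# Usage mirroring the hunk above:
# cond_func_sketch('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
```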