Remove warning about half precision on MPS (#1163)

Remove warning about half precision on MPS.
This commit is contained in:
Pedro Cuenca 2022-11-07 12:27:17 +01:00 committed by GitHub
parent b4a1ed8544
commit e86a280c45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 6 deletions

View File

@ -209,13 +209,13 @@ class DiffusionPipeline(ConfigMixin):
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
logger.warning(
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
" `float16` operations on those devices in PyTorch. Please remove the"
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` as running them will fail. Please make"
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
" support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use another device for inference."
)
module.to(torch_device)
return self