Remove warning about half precision on MPS (#1163)
Remove warning about half precision on MPS.
This commit is contained in:
parent
b4a1ed8544
commit
e86a280c45
|
@ -209,13 +209,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
|
||||
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
|
||||
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
|
||||
" `float16` operations on those devices in PyTorch. Please remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
module.to(torch_device)
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue