Raise an error when moving an fp16 pipeline to CPU (#749)

* Raise an error when moving an fp16 pipeline to CPU

* Raise an error when moving an fp16 pipeline to CPU

* style

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Improve the message

* cuda

* Update tests/test_pipelines.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Anton Lozhkov 2022-10-06 15:51:03 +02:00 committed by GitHub
parent 3383f77441
commit 6c64741933
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 0 deletions

View File

@ -166,6 +166,12 @@ class DiffusionPipeline(ConfigMixin):
for name in module_names.keys(): for name in module_names.keys():
module = getattr(self, name) module = getattr(self, name)
if isinstance(module, torch.nn.Module): if isinstance(module, torch.nn.Module):
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
raise ValueError(
"Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
"due to the lack of support for `float16` operations on those devices in PyTorch. "
"Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
)
module.to(torch_device) module.to(torch_device)
return self return self

View File

@ -188,6 +188,17 @@ class PipelineFastTests(unittest.TestCase):
return extract return extract
def test_pipeline_fp16_cpu_error(self):
model = self.dummy_uncond_unet
scheduler = DDPMScheduler(num_train_timesteps=10)
pipe = DDIMPipeline(model.half(), scheduler)
if str(torch_device) in ["cpu", "mps"]:
self.assertRaises(ValueError, pipe.to, torch_device)
else:
# moving the pipeline to GPU should work
pipe.to(torch_device)
def test_ddim(self): def test_ddim(self):
unet = self.dummy_uncond_unet unet = self.dummy_uncond_unet
scheduler = DDIMScheduler() scheduler = DDIMScheduler()