From 6c64741933c64df276b5ede21f62777dbe079cfd Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Thu, 6 Oct 2022 15:51:03 +0200
Subject: [PATCH] Raise an error when moving an fp16 pipeline to CPU (#749)

* Raise an error when moving an fp16 pipeline to CPU

* Raise an error when moving an fp16 pipeline to CPU

* style

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Pedro Cuenca

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Suraj Patil

* Improve the message

* cuda

* Update tests/test_pipelines.py

Co-authored-by: Pedro Cuenca
Co-authored-by: Suraj Patil
---
 src/diffusers/pipeline_utils.py |  6 ++++++
 tests/test_pipelines.py         | 11 +++++++++++
 2 files changed, 17 insertions(+)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 17560055..01ba4eef 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -166,6 +166,12 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
+                    raise ValueError(
+                        "Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
+                        "due to the lack of support for `float16` operations on those devices in PyTorch. "
+                        "Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
+                    )
                 module.to(torch_device)
 
         return self

diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index e6cc37ad..69301e97 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -188,6 +188,17 @@ class PipelineFastTests(unittest.TestCase):
 
         return extract
 
+    def test_pipeline_fp16_cpu_error(self):
+        model = self.dummy_uncond_unet
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        pipe = DDIMPipeline(model.half(), scheduler)
+
+        if str(torch_device) in ["cpu", "mps"]:
+            self.assertRaises(ValueError, pipe.to, torch_device)
+        else:
+            # moving the pipeline to GPU should work
+            pipe.to(torch_device)
+
     def test_ddim(self):
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
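
For reference, a minimal usage sketch (not part of the patch) of the behavior this change introduces; the pipeline class, model id, and device strings below are illustrative assumptions, not taken from the diff:

# Illustrative sketch only: shows the ValueError added by this patch when an
# fp16 pipeline is moved to "cpu" or "mps". The model id is an example, not
# something referenced by the patch.
import torch
from diffusers import DDPMPipeline

# Loading with torch_dtype=torch.float16 puts the pipeline's nn.Module components in fp16.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32", torch_dtype=torch.float16)

try:
    pipe.to("cpu")  # with this patch, raises ValueError instead of moving the modules
except ValueError as err:
    print(err)  # message suggests removing torch_dtype=torch.float16 or using a cuda device

if torch.cuda.is_available():
    pipe.to("cuda")  # moving an fp16 pipeline to a CUDA device is still allowed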