Raise an error when moving an fp16 pipeline to CPU (#749)
* Raise an error when moving an fp16 pipeline to CPU * Raise an error when moving an fp16 pipeline to CPU * style * Update src/diffusers/pipeline_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/pipeline_utils.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Improve the message * cuda * Update tests/test_pipelines.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent
3383f77441
commit
6c64741933
|
@ -166,6 +166,12 @@ class DiffusionPipeline(ConfigMixin):
|
|||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
|
||||
raise ValueError(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
|
||||
"due to the lack of support for `float16` operations on those devices in PyTorch. "
|
||||
"Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
|
||||
)
|
||||
module.to(torch_device)
|
||||
return self
|
||||
|
||||
|
|
|
@ -188,6 +188,17 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
return extract
|
||||
|
||||
def test_pipeline_fp16_cpu_error(self):
|
||||
model = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||
pipe = DDIMPipeline(model.half(), scheduler)
|
||||
|
||||
if str(torch_device) in ["cpu", "mps"]:
|
||||
self.assertRaises(ValueError, pipe.to, torch_device)
|
||||
else:
|
||||
# moving the pipeline to GPU should work
|
||||
pipe.to(torch_device)
|
||||
|
||||
def test_ddim(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
|
|
Loading…
Reference in New Issue