From 187de44352ce23acf00a9204a05a8a308aab7003 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 9 Nov 2022 22:18:14 +0000 Subject: [PATCH] Fix device on save/load tests --- tests/test_pipelines.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 775ab689..4559d713 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase): pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch") pipe_2 = pipe_2.to(torch_device) - generator_2 = generator.manual_seed(0) - out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images assert np.max(np.abs(out - out_2)) < 1e-3 @@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase): pipe.save_pretrained(tmpdirname) pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None) pipe_2 = pipe_2.to(torch_device) - generator_2 = generator.manual_seed(0) - out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images assert np.max(np.abs(out - out_2)) < 1e-3 @@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase): pipe.save_pretrained(tmpdirname) pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname) pipe_2 = pipe_2.to(torch_device) - generator_2 = generator.manual_seed(0) - out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images assert np.max(np.abs(out - out_2)) < 1e-3