diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 775ab689..4559d713 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase): pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch") pipe_2 = pipe_2.to(torch_device) - generator_2 = generator.manual_seed(0) - out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images assert np.max(np.abs(out - out_2)) < 1e-3 @@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase): pipe.save_pretrained(tmpdirname) pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None) pipe_2 = pipe_2.to(torch_device) - generator_2 = generator.manual_seed(0) - out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images assert np.max(np.abs(out - out_2)) < 1e-3 @@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase): pipe.save_pretrained(tmpdirname) pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname) pipe_2 = pipe_2.to(torch_device) - generator_2 = generator.manual_seed(0) - out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images assert np.max(np.abs(out - out_2)) < 1e-3