Fix device on save/load tests

This commit is contained in:
Patrick von Platen 2022-11-09 22:18:14 +00:00
parent 7d0c272939
commit 187de44352
1 changed files with 22 additions and 6 deletions

View File

@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase):
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
pipe_2 = pipe_2.to(torch_device)
generator_2 = generator.manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
assert np.max(np.abs(out - out_2)) < 1e-3
@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase):
pipe.save_pretrained(tmpdirname)
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
pipe_2 = pipe_2.to(torch_device)
generator_2 = generator.manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
assert np.max(np.abs(out - out_2)) < 1e-3
@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase):
pipe.save_pretrained(tmpdirname)
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
pipe_2 = pipe_2.to(torch_device)
generator_2 = generator.manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
assert np.max(np.abs(out - out_2)) < 1e-3