Fix device on save/load tests
This commit is contained in:
parent
7d0c272939
commit
187de44352
|
@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase):
|
||||||
|
|
||||||
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
generator_2 = generator.manual_seed(0)
|
if torch_device == "mps":
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
|
# device type MPS is not supported for torch.Generator() api.
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||||
|
|
||||||
|
@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase):
|
||||||
pipe.save_pretrained(tmpdirname)
|
pipe.save_pretrained(tmpdirname)
|
||||||
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
|
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
generator_2 = generator.manual_seed(0)
|
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
|
if torch_device == "mps":
|
||||||
|
# device type MPS is not supported for torch.Generator() api.
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
|
|
||||||
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||||
|
|
||||||
|
@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase):
|
||||||
pipe.save_pretrained(tmpdirname)
|
pipe.save_pretrained(tmpdirname)
|
||||||
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
generator_2 = generator.manual_seed(0)
|
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
|
if torch_device == "mps":
|
||||||
|
# device type MPS is not supported for torch.Generator() api.
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
|
|
||||||
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue