Fix device on save/load tests
This commit is contained in:
parent
7d0c272939
commit
187de44352
|
@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase):
|
|||
|
||||
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
||||
pipe_2 = pipe_2.to(torch_device)
|
||||
generator_2 = generator.manual_seed(0)
|
||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||
|
||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||
|
||||
|
@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase):
|
|||
pipe.save_pretrained(tmpdirname)
|
||||
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
|
||||
pipe_2 = pipe_2.to(torch_device)
|
||||
generator_2 = generator.manual_seed(0)
|
||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
|
||||
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||
|
||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||
|
||||
|
@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase):
|
|||
pipe.save_pretrained(tmpdirname)
|
||||
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
pipe_2 = pipe_2.to(torch_device)
|
||||
generator_2 = generator.manual_seed(0)
|
||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
|
||||
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||
|
||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||
|
||||
|
|
Loading…
Reference in New Issue