fix tests

This commit is contained in:
Patrick von Platen 2022-06-28 17:36:56 +00:00
parent e372767c4d
commit 79db3eb6ca
1 changed files with 5 additions and 4 deletions

View File

@ -845,7 +845,6 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ddpm_cifar10(self):
generator = torch.manual_seed(0)
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
@ -853,12 +852,14 @@ class PipelineTesterMixin(unittest.TestCase):
noise_scheduler = noise_scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([0.2249, 0.3375, 0.2359, 0.0929, 0.3439, 0.3156, 0.1937, 0.3585, 0.1761])
expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
@ -883,20 +884,20 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_pndm_cifar10(self):
generator = torch.manual_seed(0)
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
noise_scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
generator = torch.manual_seed(0)
image = pndm(generator=generator)
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.7925, -0.7902, -0.7789, -0.7796, -0.8000, -0.7596, -0.6852, -0.7125, -0.7494]
[-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2