fix tests
This commit is contained in:
parent
e372767c4d
commit
79db3eb6ca
|
@ -845,7 +845,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_ddpm_cifar10(self):
|
||||
generator = torch.manual_seed(0)
|
||||
model_id = "fusing/ddpm-cifar10"
|
||||
|
||||
unet = UNetModel.from_pretrained(model_id)
|
||||
|
@ -853,12 +852,14 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator)
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
assert image.shape == (1, 3, 32, 32)
|
||||
expected_slice = torch.tensor([0.2249, 0.3375, 0.2359, 0.0929, 0.3439, 0.3156, 0.1937, 0.3585, 0.1761])
|
||||
expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@slow
|
||||
|
@ -883,20 +884,20 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_pndm_cifar10(self):
|
||||
generator = torch.manual_seed(0)
|
||||
model_id = "fusing/ddpm-cifar10"
|
||||
|
||||
unet = UNetModel.from_pretrained(model_id)
|
||||
noise_scheduler = PNDMScheduler(tensor_format="pt")
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
generator = torch.manual_seed(0)
|
||||
image = pndm(generator=generator)
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
assert image.shape == (1, 3, 32, 32)
|
||||
expected_slice = torch.tensor(
|
||||
[-0.7925, -0.7902, -0.7789, -0.7796, -0.8000, -0.7596, -0.6852, -0.7125, -0.7494]
|
||||
[-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
|
||||
)
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
|
|
Loading…
Reference in New Issue