add test for ldm uncond

This commit is contained in:
patil-suraj 2022-06-29 15:25:51 +02:00
parent 65788e46ed
commit 859ffea2b1
1 changed files with 16 additions and 3 deletions

View File

@ -34,6 +34,7 @@ from diffusers import (
GradTTSPipeline, GradTTSPipeline,
GradTTSScheduler, GradTTSScheduler,
LatentDiffusionPipeline, LatentDiffusionPipeline,
LatentDiffusionUncondPipeline,
NCSNpp, NCSNpp,
PNDMPipeline, PNDMPipeline,
PNDMScheduler, PNDMScheduler,
@ -46,7 +47,6 @@ from diffusers import (
UNetLDMModel, UNetLDMModel,
UNetModel, UNetModel,
VQModel, VQModel,
AutoencoderKL,
) )
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
@ -915,7 +915,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
"out_ch": 3, "out_ch": 3,
"resolution": 32, "resolution": 32,
"z_channels": 4, "z_channels": 4,
"attn_resolutions": [] "attn_resolutions": [],
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
@ -1151,6 +1151,19 @@ class PipelineTesterMixin(unittest.TestCase):
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
@slow
def test_ldm_uncond(self):
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256")
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12) noise_scheduler = DDPMScheduler(timesteps=12)