add test for ldm uncond
This commit is contained in:
parent
65788e46ed
commit
859ffea2b1
|
@ -34,6 +34,7 @@ from diffusers import (
|
||||||
GradTTSPipeline,
|
GradTTSPipeline,
|
||||||
GradTTSScheduler,
|
GradTTSScheduler,
|
||||||
LatentDiffusionPipeline,
|
LatentDiffusionPipeline,
|
||||||
|
LatentDiffusionUncondPipeline,
|
||||||
NCSNpp,
|
NCSNpp,
|
||||||
PNDMPipeline,
|
PNDMPipeline,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
|
@ -46,7 +47,6 @@ from diffusers import (
|
||||||
UNetLDMModel,
|
UNetLDMModel,
|
||||||
UNetModel,
|
UNetModel,
|
||||||
VQModel,
|
VQModel,
|
||||||
AutoencoderKL,
|
|
||||||
)
|
)
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
|
@ -915,7 +915,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||||
"out_ch": 3,
|
"out_ch": 3,
|
||||||
"resolution": 32,
|
"resolution": 32,
|
||||||
"z_channels": 4,
|
"z_channels": 4,
|
||||||
"attn_resolutions": []
|
"attn_resolutions": [],
|
||||||
}
|
}
|
||||||
inputs_dict = self.dummy_input
|
inputs_dict = self.dummy_input
|
||||||
return init_dict, inputs_dict
|
return init_dict, inputs_dict
|
||||||
|
@ -1151,6 +1151,19 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
|
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
|
||||||
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
|
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_ldm_uncond(self):
|
||||||
|
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256")
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
image = ldm(generator=generator, num_inference_steps=5)
|
||||||
|
|
||||||
|
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||||
|
|
||||||
|
assert image.shape == (1, 3, 256, 256)
|
||||||
|
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
|
||||||
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
def test_module_from_pipeline(self):
|
def test_module_from_pipeline(self):
|
||||||
model = DiffWave(num_res_layers=4)
|
model = DiffWave(num_res_layers=4)
|
||||||
noise_scheduler = DDPMScheduler(timesteps=12)
|
noise_scheduler = DDPMScheduler(timesteps=12)
|
||||||
|
|
Loading…
Reference in New Issue