add fast test for ldm

This commit is contained in:
patil-suraj 2022-06-27 11:42:52 +02:00
parent 17bf65e186
commit 6921393ae2
3 changed files with 16 additions and 1 deletions

View File

@ -34,6 +34,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb return emb
# unet_glide.py # unet_glide.py
def timestep_embedding(timesteps, dim, max_period=10000): def timestep_embedding(timesteps, dim, max_period=10000):
""" """

View File

@ -198,7 +198,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if not isinstance(spk, type(None)): if not isinstance(spk, type(None)):
s = self.spk_mlp(spk) s = self.spk_mlp(spk)
t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = self.time_pos_emb(timesteps, scale=self.pe_scale)
t = self.mlp(t) t = self.mlp(t)

View File

@ -694,6 +694,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_ldm_text2img_fast(self):
model_id = "fusing/latent-diffusion-text2im-large"
ldm = LatentDiffusionPipeline.from_pretrained(model_id)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=20)
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.rensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow @slow
def test_glide_text2img(self): def test_glide_text2img(self):
model_id = "fusing/glide-base" model_id = "fusing/glide-base"