add fast test for ldm
This commit is contained in:
parent
17bf65e186
commit
6921393ae2
|
@ -34,6 +34,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
# unet_glide.py
|
# unet_glide.py
|
||||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -198,7 +198,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||||
if not isinstance(spk, type(None)):
|
if not isinstance(spk, type(None)):
|
||||||
s = self.spk_mlp(spk)
|
s = self.spk_mlp(spk)
|
||||||
|
|
||||||
|
|
||||||
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
|
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
|
||||||
t = self.mlp(t)
|
t = self.mlp(t)
|
||||||
|
|
||||||
|
|
|
@ -694,6 +694,21 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
||||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_ldm_text2img_fast(self):
|
||||||
|
model_id = "fusing/latent-diffusion-text2im-large"
|
||||||
|
ldm = LatentDiffusionPipeline.from_pretrained(model_id)
|
||||||
|
|
||||||
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
image = ldm([prompt], generator=generator, num_inference_steps=20)
|
||||||
|
|
||||||
|
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||||
|
|
||||||
|
assert image.shape == (1, 3, 256, 256)
|
||||||
|
expected_slice = torch.rensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
|
||||||
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_glide_text2img(self):
|
def test_glide_text2img(self):
|
||||||
model_id = "fusing/glide-base"
|
model_id = "fusing/glide-base"
|
||||||
|
|
Loading…
Reference in New Issue