adapt test
This commit is contained in:
parent
1c14ce9509
commit
43bbc78123
|
@ -36,6 +36,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
for t in tqdm.tqdm(self.scheduler.timesteps):
|
for t in tqdm.tqdm(self.scheduler.timesteps):
|
||||||
|
with torch.no_grad():
|
||||||
model_output = self.unet(image, t)
|
model_output = self.unet(image, t)
|
||||||
|
|
||||||
if isinstance(model_output, dict):
|
if isinstance(model_output, dict):
|
||||||
|
@ -46,5 +47,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
|
||||||
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
|
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
|
||||||
|
|
||||||
# decode image with vae
|
# decode image with vae
|
||||||
|
with torch.no_grad():
|
||||||
image = self.vqvae.decode(image)
|
image = self.vqvae.decode(image)
|
||||||
return {"sample": image}
|
return {"sample": image}
|
||||||
|
|
|
@ -1070,7 +1070,8 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_ldm_uncond(self):
|
def test_ldm_uncond(self):
|
||||||
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
|
# ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
|
||||||
|
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/latent-diffusion-celeba-256")
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = ldm(generator=generator, num_inference_steps=5)["sample"]
|
image = ldm(generator=generator, num_inference_steps=5)["sample"]
|
||||||
|
|
Loading…
Reference in New Issue