correct readme

This commit is contained in:
Patrick von Platen 2022-06-12 22:14:03 +00:00
parent 7764669c54
commit 20d9178237
1 changed files with 13 additions and 5 deletions

View File

@ -48,9 +48,13 @@ noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator)
image = torch.randn(
(1, unet.in_channels, unet.resolution, unet.resolution)
generator=generator,
)
image = image.to(torch_device)
# 3. Denoise
# 3. Denoise
num_prediction_steps = len(noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# predict noise residual
@ -63,7 +67,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s
# optionally sample variance
variance = 0
if t > 0:
noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = noise_scheduler.get_variance(t).sqrt() * noise
# set current image to prev_image: x_t -> x_t-1
@ -96,7 +100,11 @@ noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq")
unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device)
# 2. Sample gaussian noise
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator)
image = torch.randn(
(1, unet.in_channels, unet.resolution, unet.resolution)
generator=generator,
)
image = image.to(torch_device)
# 3. Denoise
num_inference_steps = 50
@ -114,7 +122,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = noise_scheduler.get_variance(t).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1