correct readme
This commit is contained in:
parent
7764669c54
commit
20d9178237
18
README.md
18
README.md
|
@ -48,9 +48,13 @@ noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
|
|||
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
|
||||
|
||||
# 2. Sample gaussian noise
|
||||
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator)
|
||||
image = torch.randn(
|
||||
(1, unet.in_channels, unet.resolution, unet.resolution)
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
# 3. Denoise
|
||||
# 3. Denoise
|
||||
num_prediction_steps = len(noise_scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
||||
# predict noise residual
|
||||
|
@ -63,7 +67,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s
|
|||
# optionally sample variance
|
||||
variance = 0
|
||||
if t > 0:
|
||||
noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = noise_scheduler.get_variance(t).sqrt() * noise
|
||||
|
||||
# set current image to prev_image: x_t -> x_t-1
|
||||
|
@ -96,7 +100,11 @@ noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq")
|
|||
unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device)
|
||||
|
||||
# 2. Sample gaussian noise
|
||||
image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator)
|
||||
image = torch.randn(
|
||||
(1, unet.in_channels, unet.resolution, unet.resolution)
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
# 3. Denoise
|
||||
num_inference_steps = 50
|
||||
|
@ -114,7 +122,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste
|
|||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = noise_scheduler.get_variance(t).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
|
|
Loading…
Reference in New Issue