From acb2faaefad467f60a3f338021f9ef5d53244c06 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 16 Jun 2022 10:22:55 +0200 Subject: [PATCH] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a8a2ca80..0a8d62cd 100644 --- a/README.md +++ b/README.md @@ -137,8 +137,8 @@ unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device) # 2. Sample gaussian noise image = torch.randn( - (1, unet.in_channels, unet.resolution, unet.resolution), - generator=generator, + (1, unet.in_channels, unet.resolution, unet.resolution), + generator=generator, ) image = image.to(torch_device) @@ -147,10 +147,10 @@ num_inference_steps = 50 eta = 0.0 # <- deterministic sampling for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - # 1. predict noise residual + # 1. predict noise residual orig_t = noise_scheduler.get_orig_t(t, num_inference_steps) with torch.no_grad(): - residual = unet(image, orig_t) + residual = unet(image, orig_t) # 2. predict previous mean of image x_t-1 pred_prev_image = noise_scheduler.step(residual, image, t, num_inference_steps, eta)