Rename `clip_*` coefficient variables to `clipped_*` (clip => clipped)

This commit is contained in:
patil-suraj 2022-06-07 16:34:44 +02:00
parent 5aea843a41
commit f39020bd8a
3 changed files with 20 additions and 20 deletions

View File

@@ -46,10 +46,10 @@ image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.re
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@@ -57,9 +57,9 @@ for t in reversed(range(len(scheduler))):
# iii) compute predicted image from residual # iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1) pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clip_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

View File

@@ -36,10 +36,10 @@ class DDPM(DiffusionPipeline):
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t # i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t)) image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t)) clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@@ -47,9 +47,9 @@ class DDPM(DiffusionPipeline):
# iii) compute predicted image from residual # iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1) pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clip_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

View File

@@ -126,10 +126,10 @@ class SamplerTesterMixin(unittest.TestCase):
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@@ -137,9 +137,9 @@ class SamplerTesterMixin(unittest.TestCase):
# iii) compute predicted image from residual # iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1) pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clip_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
@@ -176,10 +176,10 @@ class SamplerTesterMixin(unittest.TestCase):
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
@@ -187,9 +187,9 @@ class SamplerTesterMixin(unittest.TestCase):
# iii) compute predicted image from residual # iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1) pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clip_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)