some renaming
parent 9d32a26579
commit d3e79144e6

README.md | 114
@@ -22,9 +22,9 @@
 `diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can use only parts of the library easily for their own use cases.
 It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
-Both models and scredulers should be load- and saveable from the Hub.
+Both models and schedulers should be load- and saveable from the Hub.
 
-Example:
+Example for [DDPM](https://arxiv.org/abs/2006.11239):
 
 ```python
 import torch
@@ -32,65 +32,91 @@ from diffusers import UNetModel, GaussianDDPMScheduler
 import PIL
 import numpy as np
+import tqdm
 
-generator = torch.Generator()
-generator = generator.manual_seed(6694729458485568)
+generator = torch.manual_seed(0)
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # 1. Load models
-scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
+noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
 model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
 
 # 2. Sample gaussian noise
-image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
 
 # 3. Denoise
-for t in reversed(range(len(scheduler))):
-    # 1. predict noise residual
-    with torch.no_grad():
-        pred_noise_t = model(image, t)
-
-    # 2. compute alphas, betas
-    alpha_prod_t = scheduler.get_alpha_prod(t)
-    alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)
-    beta_prod_t = 1 - alpha_prod_t
-    beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-    # 3. compute predicted image from residual
-    # First: compute predicted original image from predicted noise also called
-    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-    pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
-
-    # Second: Clip "predicted x_0"
-    pred_original_image = torch.clamp(pred_original_image, -1, 1)
-
-    # Third: Compute coefficients for pred_original_image x_0 and current image x_t
-    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-    pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * scheduler.get_beta(t)) / beta_prod_t
-    current_image_coeff = scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
-
-    # Fourth: Compute predicted previous image µ_t
-    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-    pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
-
-    # 5. For t > 0, compute predicted variance βt (see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
-    # and sample from it to get previous image
-    # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
-    if t > 0:
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * scheduler.get_beta(t).sqrt()
-        noise = scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-        prev_image = pred_prev_image + variance * noise
-    else:
-        prev_image = pred_prev_image
-
-    # 6. Set current image to prev_image: x_t -> x_t-1
-    image = prev_image
+num_prediction_steps = len(noise_scheduler)
+for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+    # predict noise residual
+    with torch.no_grad():
+        residual = model(image, t)
+
+    # predict previous mean of image x_t-1
+    pred_prev_image = noise_scheduler.get_prev_image_step(residual, image, t)
+
+    # optionally sample variance
+    variance = 0
+    if t > 0:
+        noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+        variance = noise_scheduler.get_variance(t).sqrt() * noise
+
+    # set current image to prev_image: x_t -> x_t-1
+    image = pred_prev_image + variance
 
-# process image to PIL
+# 5. process image to PIL
 image_processed = image.cpu().permute(0, 2, 3, 1)
 image_processed = (image_processed + 1.0) * 127.5
 image_processed = image_processed.numpy().astype(np.uint8)
 image_pil = PIL.Image.fromarray(image_processed[0])
 
-# save image
+# 6. save image
 image_pil.save("test.png")
 ```
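For reference, the update that `get_prev_image_step` encapsulates here (and that the removed inline code computed step by step) is the DDPM posterior mean and variance, formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf, with the clipped "predicted x_0" plugged in for x_0:

```latex
\tilde{\mu}_t(x_t, x_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,\bigl(1 - \bar{\alpha}_{t-1}\bigr)}{1 - \bar{\alpha}_t}\, x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t
```

For t > 0 the loop then adds \(\sqrt{\tilde{\beta}_t}\,\epsilon\), which is the `get_variance(t).sqrt() * noise` term above.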
 
+Example for [DDIM](https://arxiv.org/abs/2010.02502):
+
+```python
+import torch
+from diffusers import UNetModel, DDIMScheduler
+import PIL
+import numpy as np
+import tqdm
+
+generator = torch.manual_seed(0)
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# 1. Load models
+noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq")
+model = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device)
+
+# 2. Sample gaussian noise
+image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+
+# 3. Denoise
+num_inference_steps = 50
+eta = 0.0  # <- deterministic sampling
+# map the num_inference_steps inference steps to the scheduler's training timesteps
+inference_step_times = range(0, len(noise_scheduler), len(noise_scheduler) // num_inference_steps)
+
+for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
+    # 1. predict noise residual
+    with torch.no_grad():
+        residual = model(image, inference_step_times[t])
+
+    # 2. predict previous mean of image x_t-1
+    pred_prev_image = noise_scheduler.get_prev_image_step(residual, image, t, num_inference_steps, eta)
+
+    # 3. optionally sample variance
+    variance = 0
+    if eta > 0:
+        noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+        variance = noise_scheduler.get_variance(t).sqrt() * eta * noise
+
+    # 4. set current image to prev_image: x_t -> x_t-1
+    image = pred_prev_image + variance
+
+# 5. process image to PIL
+image_processed = image.cpu().permute(0, 2, 3, 1)
+image_processed = (image_processed + 1.0) * 127.5
+image_processed = image_processed.numpy().astype(np.uint8)
+image_pil = PIL.Image.fromarray(image_processed[0])
+
+# 6. save image
+image_pil.save("test.png")
+```
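Setting η = 0 zeroes the σ_t ε noise term, which is why the sampling above is deterministic; η = 1 essentially recovers DDPM sampling. The step corresponds to formula (12) of https://arxiv.org/pdf/2010.02502.pdf with σ_t(η) from formula (16) (writing the DDIM paper's cumulative α_t as ᾱ_t to match the DDPM notation above):

```latex
x_{t-1}
  = \sqrt{\bar{\alpha}_{t-1}}
    \underbrace{\left(\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar{\alpha}_t}}\right)}_{\text{predicted } x_0}
  + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t)
  + \sigma_t\,\epsilon_t,
\qquad
\sigma_t(\eta) = \eta\,\sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}}\,\sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}
```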
@@ -58,7 +58,7 @@ class DDIM(DiffusionPipeline):
                 residual = self.unet(image, inference_step_times[t])
 
             # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.noise_scheduler.predict_prev_image_step(residual, image, t, num_inference_steps, eta)
+            pred_prev_image = self.noise_scheduler.get_prev_image_step(residual, image, t, num_inference_steps, eta)
 
             # 3. optionally sample variance
             variance = 0
@@ -69,44 +69,4 @@ class DDIM(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
 
-            # 2. get actual t and t-1
-            # train_step = inference_step_times[t]
-            # prev_train_step = inference_step_times[t - 1] if t > 0 else -1
-            #
-            # 3. compute alphas, betas
-            # alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
-            # alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-            # beta_prod_t = 1 - alpha_prod_t
-            # beta_prod_t_prev = 1 - alpha_prod_t_prev
-            #
-            # 4. Compute predicted previous image from predicted noise
-            # First: compute predicted original image from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            # pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
-            #
-            # Second: Clip "predicted x_0"
-            # pred_original_image = torch.clamp(pred_original_image, -1, 1)
-            #
-            # Third: Compute variance: "sigma_t(η)" -> see formula (16)
-            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-            # std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
-            # std_dev_t = eta * std_dev_t
-            #
-            # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            # pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
-            #
-            # Fifth: Compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            # pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
-            #
-            # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
-            # Note: eta = 1.0 essentially corresponds to DDPM
-            # if eta > 0.0:
-            #     noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-            #     prev_image = pred_prev_image + std_dev_t * noise
-            # else:
-            #     prev_image = pred_prev_image
-            #
-            # 6. Set current image to prev_image: x_t -> x_t-1
-            # image = prev_image
 
         return image
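The commented-out derivation deleted above remains useful documentation for what `get_prev_image_step` computes. A minimal standalone sketch of that update, assuming `alpha_prod_t`/`alpha_prod_t_prev` are tensors obtained from the scheduler's `get_alpha_prod` (the real method's signature and bookkeeping differ):

```python
import torch

def ddim_prev_image(pred_noise_t, image, alpha_prod_t, alpha_prod_t_prev, eta):
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # "predicted x_0" of formula (12), clipped for stability
    pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
    pred_original_image = torch.clamp(pred_original_image, -1, 1)

    # sigma_t(eta) of formula (16)
    std_dev_t = eta * (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()

    # "direction pointing to x_t" of formula (12)
    pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2).sqrt() * pred_noise_t

    # x_{t-1} without the random-noise term; the caller adds std_dev_t * noise if eta > 0
    return alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
```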
@@ -39,20 +39,19 @@ class DDPM(DiffusionPipeline):
         )
 
         num_prediction_steps = len(self.noise_scheduler)
 
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
             with torch.no_grad():
                 residual = self.unet(image, t)
 
             # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.noise_scheduler.predict_prev_image_step(residual, image, t)
+            pred_prev_image = self.noise_scheduler.get_prev_image_step(residual, image, t)
 
             # 3. optionally sample variance
             variance = 0
             if t > 0:
                 noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-                variance = self.noise_scheduler.get_variance(t) * noise
+                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
 
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
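The second change in this hunk is easy to miss: `get_variance(t)` returns a variance, but sampling x ~ N(mean, variance) as mean + std * noise needs the standard deviation, hence the added `.sqrt()`. A minimal sketch (values hypothetical):

```python
import torch

mean = torch.zeros(4)
variance = torch.full((4,), 0.04)
noise = torch.randn(4)

# mean + std * noise, where std = variance.sqrt() = 0.2
sample = mean + variance.sqrt() * noise
```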
@@ -100,7 +100,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
 
         return variance
 
-    def predict_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False):
+    def get_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False):
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read the DDIM paper for an in-detail understanding
 
@@ -47,6 +47,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         )
         self.num_timesteps = int(timesteps)
         self.clip_image = clip_predicted_image
+        self.variance_type = variance_type
 
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
@@ -97,11 +98,17 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         # For t > 0, compute predicted variance βt (see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous image
         # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t).sqrt()
+        variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t))
+
+        # hacks - were probably added for training stability
+        if self.variance_type == "fixed_small":
+            variance = variance.clamp(min=1e-20)
+        elif self.variance_type == "fixed_large":
+            variance = self.get_beta(t)
 
         return variance
 
-    def predict_prev_image_step(self, residual, image, t, output_pred_x_0=False):
+    def get_prev_image_step(self, residual, image, t, output_pred_x_0=False):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
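A standalone sketch of the new variance logic, with `betas`/`alphas_cumprod` standing in for whatever buffers the real scheduler keeps (an assumption), so the `fixed_small`/`fixed_large` branches can be tried in isolation:

```python
import torch

def get_variance(t, betas, alphas_cumprod, variance_type="fixed_small"):
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)

    # posterior variance beta~_t, formula (7) of https://arxiv.org/pdf/2006.11239.pdf
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * betas[t]

    if variance_type == "fixed_small":
        variance = variance.clamp(min=1e-20)  # keep strictly positive
    elif variance_type == "fixed_large":
        variance = betas[t]
    return variance

# usage sketch with a linear beta schedule
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
print(get_variance(10, betas, alphas_cumprod))
```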