Update README.md
This commit is contained in:
parent
fe3137304b
commit
6259f2a5f2
56
README.md
56
README.md
|
@ -27,25 +27,51 @@ One should be able to save both models and samplers as well as load them from th
|
|||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import UNetModel, GaussianDDPMScheduler
|
||||
import torch
|
||||
from diffusers import UNetModel, GaussianDDPMScheduler
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
# 1. Load model
|
||||
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(6694729458485568)
|
||||
|
||||
# 2. Do one denoising step with model
|
||||
batch_size, num_channels, height, width = 1, 3, 32, 32
|
||||
dummy_noise = torch.ones((batch_size, num_channels, height, width))
|
||||
time_step = torch.tensor([10])
|
||||
image = unet(dummy_noise, time_step)
|
||||
# 1. Load models
|
||||
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
|
||||
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
|
||||
|
||||
# 3. Load sampler
|
||||
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
|
||||
# 2. Sample gaussian noise
|
||||
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
|
||||
|
||||
# 4. Sample image from sampler passing the model
|
||||
image = sampler.sample(model, batch_size=1)
|
||||
# 3. Denoise
|
||||
for t in reversed(range(len(scheduler))):
|
||||
# i) define coefficients for time step t
|
||||
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
|
||||
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
|
||||
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
|
||||
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
|
||||
|
||||
print(image)
|
||||
# ii) predict noise residual
|
||||
with torch.no_grad():
|
||||
noise_residual = model(image, t)
|
||||
|
||||
# iii) compute predicted image from residual
|
||||
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
|
||||
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
|
||||
pred_mean = torch.clamp(pred_mean, -1, 1)
|
||||
prev_image = clip_coeff * pred_mean + image_coeff * image
|
||||
|
||||
# iv) sample variance
|
||||
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
|
||||
|
||||
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
|
||||
sampled_prev_image = prev_image + prev_variance
|
||||
image = sampled_prev_image
|
||||
|
||||
image_processed = image.cpu().permute(0, 2, 3, 1)
|
||||
image_processed = (image_processed + 1.0) * 127.5
|
||||
image_processed = image_processed.numpy().astype(np.uint8)
|
||||
image_pil = PIL.Image.fromarray(image_processed[0])
|
||||
image_pil.save("test.png")
|
||||
```
|
||||
|
||||
## 2. `diffusers` as a collection of most import Diffusion models (GLIDE, Dalle, ...)
|
||||
|
@ -117,8 +143,8 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|||
│ ├── models
|
||||
│ │ └── unet.py
|
||||
│ ├── processors
|
||||
│ └── samplers
|
||||
│ ├── gaussian.py
|
||||
│ └── schedulers
|
||||
│ ├── gaussian_ddpm.py
|
||||
├── tests
|
||||
│ └── test_modeling_utils.py
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue