finalize
parent
d3e79144e6
commit
c836efcfdc
18
README.md
@@ -7,7 +7,7 @@
 
 ![model_diff_1_50](https://user-images.githubusercontent.com/23423619/171610307-dab0cd8b-75da-4d4e-9f5a-5922072e2bb5.png)
 
-**Schedulers**: Algorithm to sample noise schedule for both *training* and *inference*. Defines alpha and beta schedule, timesteps, etc..
+**Schedulers**: Algorithm to compute previous image according to alpha, beta schedule and to sample noise. Should be used for both *training* and *inference*.
 *Example: Gaussian DDPM, DDIM, PMLS, DEIN*
 
 ![sampling](https://user-images.githubusercontent.com/23423619/171608981-3ad05953-a684-4c82-89f8-62a459147a07.png)
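The scheduler description in this hunk refers to an "alpha, beta schedule" without showing one. As a rough illustration only, the sketch below builds the linear beta schedule described in the DDPM paper and the cumulative alpha products derived from it. The names `linear_beta_schedule` and `cumulative_alphas` are hypothetical and do not come from this repository, and the default endpoints (1e-4 and 0.02) are the DDPM paper's values, which may differ from what the schedulers here use. It is written in plain Python to stay dependency-free.

```python
def linear_beta_schedule(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Return `num_timesteps` betas spaced evenly from beta_start to beta_end.

    Hypothetical helper; the defaults are the DDPM paper's values.
    """
    if num_timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]


def cumulative_alphas(betas):
    """Return alpha_prod[t], the running product of (1 - beta_s) for s <= t."""
    alpha_prods = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_prods.append(running)
    return alpha_prods


betas = linear_beta_schedule(1000)
alpha_prods = cumulative_alphas(betas)
```

The cumulative product is the quantity that the scheduler code further down calls `alpha_prod_t`; it shrinks toward zero as `t` grows, which is what turns an image into nearly pure noise at the last timestep.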
@@ -18,13 +18,15 @@
 
 ![imagen](https://user-images.githubusercontent.com/23423619/171609001-c3f2c1c9-f597-4a16-9843-749bf3f9431c.png)
 
-## 1. `diffusers` as a central modular diffusion and sampler library
+## Quickstart
+
+### 1. `diffusers` as a central modular diffusion and sampler library
 
 `diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can use only parts of the library easily for the own use cases.
 It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
 Both models and schedulers should be load- and saveable from the Hub.
 
-Example for [DDPM](https://arxiv.org/abs/2006.11239):
+**Example for [DDPM](https://arxiv.org/abs/2006.11239):**
 
 ```python
 import torch
@@ -50,7 +52,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s
 residual = self.unet(image, t)
 
 # predict previous mean of image x_t-1
-pred_prev_image = noise_scheduler.get_prev_image_step(residual, image, t)
+pred_prev_image = noise_scheduler.compute_prev_image_step(residual, image, t)
 
 # optionally sample variance
 variance = 0
@@ -71,7 +73,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
 
-Example for [DDIM](https://arxiv.org/abs/2010.02502):
+**Example for [DDIM](https://arxiv.org/abs/2010.02502):**
 
 ```python
 import torch
@@ -99,7 +101,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste
 residual = self.unet(image, inference_step_times[t])
 
 # 2. predict previous mean of image x_t-1
-pred_prev_image = noise_scheduler.get_prev_image_step(residual, image, t, num_inference_steps, eta)
+pred_prev_image = noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
 
 # 3. optionally sample variance
 variance = 0
@@ -120,10 +122,10 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
 
-## 2. `diffusers` as a collection of most important Diffusion systems (GLIDE, Dalle, ...)
+### 2. `diffusers` as a collection of most important Diffusion systems (GLIDE, Dalle, ...)
 `models` directory in repository hosts the complete code necessary for running a diffusion system as well as to train it. A `DiffusionPipeline` class allows to easily run the diffusion model in inference:
 
-Example:
+**Example image generation with DDPM**
 
 ```python
 from diffusers import DiffusionPipeline
@@ -58,7 +58,7 @@ class DDIM(DiffusionPipeline):
 residual = self.unet(image, inference_step_times[t])
 
 # 2. predict previous mean of image x_t-1
-pred_prev_image = self.noise_scheduler.get_prev_image_step(residual, image, t, num_inference_steps, eta)
+pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
 
 # 3. optionally sample variance
 variance = 0
@@ -45,7 +45,7 @@ class DDPM(DiffusionPipeline):
 residual = self.unet(image, t)
 
 # 2. predict previous mean of image x_t-1
-pred_prev_image = self.noise_scheduler.get_prev_image_step(residual, image, t)
+pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t)
 
 # 3. optionally sample variance
 variance = 0
@@ -100,7 +100,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
 
 return variance
 
-def get_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False):
+def compute_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False):
 # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
 # Ideally, read DDIM paper in-detail understanding
 
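The body of the renamed DDIM method is not visible in this hunk, only its pointer to formulas (12) and (16) of the DDIM paper. For orientation, here is a minimal sketch of what those two formulas compute, under several assumptions: the function name `ddim_prev_mean` is hypothetical, it takes the cumulative alpha products directly instead of `t` and `num_inference_steps`, it operates on flat lists of floats, and it assumes `residual` is the predicted noise. It is not the repository's implementation.

```python
import math


def ddim_prev_mean(residual, image, alpha_prod_t, alpha_prod_t_prev, eta=0.0):
    """Hypothetical sketch of the DDIM update without the random noise term.

    Returns (mean of x_{t-1}, sigma_t). With eta == 0 the step is deterministic.
    """
    beta_prod_t = 1.0 - alpha_prod_t
    beta_prod_t_prev = 1.0 - alpha_prod_t_prev

    # Formula (16): standard deviation of the optional noise, scaled by eta.
    sigma = (
        eta
        * math.sqrt(beta_prod_t_prev / beta_prod_t)
        * math.sqrt(1.0 - alpha_prod_t / alpha_prod_t_prev)
    )

    # Formula (12), first term: estimate x_0 from x_t and the predicted noise.
    pred_x0 = [
        (x - math.sqrt(beta_prod_t) * eps) / math.sqrt(alpha_prod_t)
        for x, eps in zip(image, residual)
    ]

    # Formula (12), second term: scale of the "direction pointing to x_t".
    # max() guards against tiny negative values from floating-point rounding.
    direction_scale = math.sqrt(max(beta_prod_t_prev - sigma**2, 0.0))

    mean = [
        math.sqrt(alpha_prod_t_prev) * x0 + direction_scale * eps
        for x0, eps in zip(pred_x0, residual)
    ]
    return mean, sigma
```

In the calling loops shown earlier, the noise term would be added separately in the "optionally sample variance" step, which is why the sketch returns `sigma` instead of sampling.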
@@ -108,7 +108,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
 
 return variance
 
-def get_prev_image_step(self, residual, image, t, output_pred_x_0=False):
+def compute_prev_image_step(self, residual, image, t, output_pred_x_0=False):
 # 1. compute alphas, betas
 alpha_prod_t = self.get_alpha_prod(t)
 alpha_prod_t_prev = self.get_alpha_prod(t - 1)
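The hunk above stops after the alpha products are fetched, so the rest of the DDPM step is not shown. As a hedged illustration of how such a step is usually completed, the sketch below computes the posterior mean of x_{t-1} from the DDPM paper. The function name `ddpm_prev_mean` is hypothetical, it works on flat lists of floats instead of tensors, it assumes `residual` is the predicted noise, and it omits any clipping of the predicted x_0 that the real method may perform.

```python
import math


def ddpm_prev_mean(residual, image, alpha_prod_t, alpha_prod_t_prev):
    """Hypothetical sketch: DDPM posterior mean of x_{t-1} given x_t and predicted noise."""
    # Per-step alpha and beta recovered from the cumulative products.
    alpha_t = alpha_prod_t / alpha_prod_t_prev
    beta_t = 1.0 - alpha_t
    beta_prod_t = 1.0 - alpha_prod_t

    # Estimate the clean image x_0 from x_t and the predicted noise.
    pred_x0 = [
        (x - math.sqrt(beta_prod_t) * eps) / math.sqrt(alpha_prod_t)
        for x, eps in zip(image, residual)
    ]

    # Posterior mean is a weighted sum of the x_0 estimate and the current x_t.
    coef_x0 = math.sqrt(alpha_prod_t_prev) * beta_t / beta_prod_t
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_prod_t_prev) / beta_prod_t
    return [coef_x0 * x0 + coef_xt * x for x0, x in zip(pred_x0, image)]
```

When `alpha_prod_t_prev` is 1.0 (the final denoising step), the weight on x_t vanishes and the result is just the x_0 estimate.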