add configs for training unconditional/class-conditional ldms
This commit is contained in:
parent
f8b4a07105
commit
171cf29fb5
100
README.md
100
README.md
|
@ -55,18 +55,7 @@ bash scripts/download_first_stages.sh
|
|||
```
|
||||
|
||||
The first stage models can then be found in `models/first_stage_models/<model_spec>`
|
||||
### Training autoencoder models
|
||||
|
||||
Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
|
||||
Training can be started by running
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec> -t --gpus 0,
|
||||
```
|
||||
where `config_spec` is one of {`autoencoder_kl_8x8x64.yaml`(f=32, d=64), `autoencoder_kl_16x16x16.yaml`(f=16, d=16),
|
||||
`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
|
||||
|
||||
For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
|
||||
repository.
|
||||
|
||||
|
||||
## Pretrained LDMs
|
||||
|
@ -78,9 +67,10 @@ repository.
|
|||
| LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
|
||||
| ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
|
||||
| Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
|
||||
| OpenImages | Super-resolution | N/A | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
|
||||
| OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
|
||||
| OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
|
||||
| Landscapes (finetuned 512) | Semantic Image Synthesis | LDM-VQ-4 (100 DDIM steps, eta=1) | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | |
|
||||
| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
|
||||
| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |
|
||||
|
||||
|
||||
### Get the models
|
||||
|
@ -116,10 +106,90 @@ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inp
|
|||
`indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
|
||||
the examples provided in `data/inpainting_examples`.
|
||||
|
||||
|
||||
# Train your own LDMs
|
||||
|
||||
## Data preparation
|
||||
|
||||
### Faces
|
||||
For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
|
||||
repository.
|
||||
|
||||
### LSUN
|
||||
|
||||
The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
|
||||
We performed a custom split into training and validation images, and provide the corresponding filenames
|
||||
at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
|
||||
After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
|
||||
also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
|
||||
|
||||
### ImageNet
|
||||
The code will try to download (through [Academic
|
||||
Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
|
||||
is used. However, since ImageNet is quite large, this requires a lot of disk
|
||||
space and time. If you already have ImageNet on your disk, you can speed things
|
||||
up by putting the data into
|
||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
|
||||
`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
|
||||
of `train`/`validation`. It should have the following structure:
|
||||
|
||||
```
|
||||
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
|
||||
├── n01440764
|
||||
│ ├── n01440764_10026.JPEG
|
||||
│ ├── n01440764_10027.JPEG
|
||||
│ ├── ...
|
||||
├── n01443537
|
||||
│ ├── n01443537_10007.JPEG
|
||||
│ ├── n01443537_10014.JPEG
|
||||
│ ├── ...
|
||||
├── ...
|
||||
```
|
||||
|
||||
If you haven't extracted the data, you can also place
|
||||
`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
|
||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
|
||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
|
||||
extracted into above structure without downloading it again. Note that this
|
||||
will only happen if neither a folder
|
||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
|
||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
|
||||
if you want to force running the dataset preparation again.
|
||||
|
||||
|
||||
## Model Training
|
||||
|
||||
Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
|
||||
|
||||
### Training autoencoder models
|
||||
|
||||
Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
|
||||
Training can be started by running
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
|
||||
```
|
||||
where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
|
||||
`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
|
||||
|
||||
For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
|
||||
repository.
|
||||
|
||||
### Training LDMs
|
||||
|
||||
In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets.
|
||||
Training can be started by running
|
||||
|
||||
```shell script
|
||||
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
|
||||
```
|
||||
|
||||
where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),`ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
|
||||
`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
|
||||
`lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
|
||||
|
||||
## Coming Soon...
|
||||
|
||||
* Code for training LDMs and the corresponding compression models.
|
||||
* Inference scripts for conditional LDMs for various conditioning modalities.
|
||||
* More inference scripts for conditional LDMs.
|
||||
* In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
|
||||
* We will also release some further pretrained models.
|
||||
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
model:
|
||||
base_learning_rate: 2.0e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.0015
|
||||
linear_end: 0.0195
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
image_size: 64
|
||||
channels: 3
|
||||
monitor: val/loss_simple_ema
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 64
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
model_channels: 224
|
||||
attention_resolutions:
|
||||
# note: this isn\t actually the resolution but
|
||||
# the downsampling factor, i.e. this corresnponds to
|
||||
# attention on spatial resolution 8,16,32, as the
|
||||
# spatial reolution of the latents is 64 for f4
|
||||
- 8
|
||||
- 4
|
||||
- 2
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
num_head_channels: 32
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.VQModelInterface
|
||||
params:
|
||||
embed_dim: 3
|
||||
n_embed: 8192
|
||||
ckpt_path: models/first_stage_models/vq-f4/model.ckpt
|
||||
ddconfig:
|
||||
double_z: false
|
||||
z_channels: 3
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
cond_stage_config: __is_unconditional__
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 48
|
||||
num_workers: 5
|
||||
wrap: false
|
||||
train:
|
||||
target: taming.data.faceshq.CelebAHQTrain
|
||||
params:
|
||||
size: 256
|
||||
validation:
|
||||
target: taming.data.faceshq.CelebAHQValidation
|
||||
params:
|
||||
size: 256
|
||||
|
||||
|
||||
lightning:
|
||||
callbacks:
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 5000
|
||||
max_images: 8
|
||||
increase_log_steps: False
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
|
@ -0,0 +1,98 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.0015
|
||||
linear_end: 0.0195
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: class_label
|
||||
image_size: 32
|
||||
channels: 4
|
||||
cond_stage_trainable: true
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 256
|
||||
attention_resolutions:
|
||||
#note: this isn\t actually the resolution but
|
||||
# the downsampling factor, i.e. this corresnponds to
|
||||
# attention on spatial resolution 8,16,32, as the
|
||||
# spatial reolution of the latents is 32 for f8
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
num_head_channels: 32
|
||||
use_spatial_transformer: true
|
||||
transformer_depth: 1
|
||||
context_dim: 512
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.VQModelInterface
|
||||
params:
|
||||
embed_dim: 4
|
||||
n_embed: 16384
|
||||
ckpt_path: configs/first_stage_models/vq-f8/model.yaml
|
||||
ddconfig:
|
||||
double_z: false
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 2
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions:
|
||||
- 32
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.ClassEmbedder
|
||||
params:
|
||||
embed_dim: 512
|
||||
key: class_label
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 64
|
||||
num_workers: 12
|
||||
wrap: false
|
||||
train:
|
||||
target: ldm.data.imagenet.ImageNetTrain
|
||||
params:
|
||||
config:
|
||||
size: 256
|
||||
validation:
|
||||
target: ldm.data.imagenet.ImageNetValidation
|
||||
params:
|
||||
config:
|
||||
size: 256
|
||||
|
||||
|
||||
lightning:
|
||||
callbacks:
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 5000
|
||||
max_images: 8
|
||||
increase_log_steps: False
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
|
@ -0,0 +1,85 @@
|
|||
model:
|
||||
base_learning_rate: 2.0e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.0015
|
||||
linear_end: 0.0195
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
image_size: 64
|
||||
channels: 3
|
||||
monitor: val/loss_simple_ema
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 64
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
model_channels: 224
|
||||
attention_resolutions:
|
||||
# note: this isn\t actually the resolution but
|
||||
# the downsampling factor, i.e. this corresnponds to
|
||||
# attention on spatial resolution 8,16,32, as the
|
||||
# spatial reolution of the latents is 64 for f4
|
||||
- 8
|
||||
- 4
|
||||
- 2
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
num_head_channels: 32
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.VQModelInterface
|
||||
params:
|
||||
embed_dim: 3
|
||||
n_embed: 8192
|
||||
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
|
||||
ddconfig:
|
||||
double_z: false
|
||||
z_channels: 3
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
cond_stage_config: __is_unconditional__
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 42
|
||||
num_workers: 5
|
||||
wrap: false
|
||||
train:
|
||||
target: taming.data.faceshq.FFHQTrain
|
||||
params:
|
||||
size: 256
|
||||
validation:
|
||||
target: taming.data.faceshq.FFHQValidation
|
||||
params:
|
||||
size: 256
|
||||
|
||||
|
||||
lightning:
|
||||
callbacks:
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 5000
|
||||
max_images: 8
|
||||
increase_log_steps: False
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
|
@ -0,0 +1,85 @@
|
|||
model:
|
||||
base_learning_rate: 2.0e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.0015
|
||||
linear_end: 0.0195
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
image_size: 64
|
||||
channels: 3
|
||||
monitor: val/loss_simple_ema
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 64
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
model_channels: 224
|
||||
attention_resolutions:
|
||||
# note: this isn\t actually the resolution but
|
||||
# the downsampling factor, i.e. this corresnponds to
|
||||
# attention on spatial resolution 8,16,32, as the
|
||||
# spatial reolution of the latents is 64 for f4
|
||||
- 8
|
||||
- 4
|
||||
- 2
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
num_head_channels: 32
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.VQModelInterface
|
||||
params:
|
||||
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
|
||||
embed_dim: 3
|
||||
n_embed: 8192
|
||||
ddconfig:
|
||||
double_z: false
|
||||
z_channels: 3
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
cond_stage_config: __is_unconditional__
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 48
|
||||
num_workers: 5
|
||||
wrap: false
|
||||
train:
|
||||
target: ldm.data.lsun.LSUNBedroomsTrain
|
||||
params:
|
||||
size: 256
|
||||
validation:
|
||||
target: ldm.data.lsun.LSUNBedroomsValidation
|
||||
params:
|
||||
size: 256
|
||||
|
||||
|
||||
lightning:
|
||||
callbacks:
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 5000
|
||||
max_images: 8
|
||||
increase_log_steps: False
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
|
@ -45,7 +45,7 @@ model:
|
|||
params:
|
||||
embed_dim: 4
|
||||
monitor: "val/rec_loss"
|
||||
ckpt_path: "/export/compvis-nfs/user/ablattma/logs/braket/2021-11-26T11-25-56_lsun_churches-convae-f8-ft_from_oi/checkpoints/step=000180071-fidfrechet_inception_distance=2.335.ckpt"
|
||||
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
|
||||
ddconfig:
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
|
@ -65,7 +65,7 @@ model:
|
|||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 24 # TODO: was 96 in our experiments
|
||||
batch_size: 96
|
||||
num_workers: 5
|
||||
wrap: False
|
||||
train:
|
||||
|
@ -82,14 +82,10 @@ lightning:
|
|||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 1000 # TODO 5000
|
||||
batch_frequency: 5000
|
||||
max_images: 8
|
||||
increase_log_steps: False
|
||||
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
target: pytorch_lightning.callbacks.ModelCheckpoint
|
||||
params:
|
||||
every_n_train_steps: 20000
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
|
@ -5,8 +5,7 @@ import numpy as np
|
|||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.models.diffusion.ddpm import noise_like
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
|
@ -27,8 +26,7 @@ class DDIMSampler(object):
|
|||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32, device=self.model.device)
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
|
@ -73,7 +71,8 @@ class DDIMSampler(object):
|
|||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100
|
||||
log_every_t=100,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
|
|
|
@ -16,14 +16,14 @@ from contextlib import contextmanager
|
|||
from functools import partial
|
||||
from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
from PIL import Image
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
|
||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
|
||||
__conditioning_keys__ = {'concat': 'c_concat',
|
||||
|
@ -37,12 +37,6 @@ def disabled_train(self, mode=True):
|
|||
return self
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def uniform_on_device(r1, r2, shape, device):
|
||||
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
||||
|
||||
|
@ -119,6 +113,7 @@ class DDPM(pl.LightningModule):
|
|||
if self.learn_logvar:
|
||||
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
||||
|
||||
|
||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if exists(given_betas):
|
||||
|
@ -1188,7 +1183,6 @@ class LatentDiffusion(DDPM):
|
|||
|
||||
if start_T is not None:
|
||||
timesteps = min(timesteps, start_T)
|
||||
print(timesteps, start_T)
|
||||
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
||||
range(0, timesteps))
|
||||
|
||||
|
@ -1222,7 +1216,7 @@ class LatentDiffusion(DDPM):
|
|||
@torch.no_grad()
|
||||
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
||||
verbose=True, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, shape=None):
|
||||
mask=None, x0=None, shape=None,**kwargs):
|
||||
if shape is None:
|
||||
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
||||
if cond is not None:
|
||||
|
@ -1238,10 +1232,28 @@ class LatentDiffusion(DDPM):
|
|||
mask=mask, x0=x0)
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, n_row=4, sample=True, sample_ddim=False, return_keys=None,
|
||||
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
||||
|
||||
if ddim:
|
||||
ddim_sampler = DDIMSampler(self)
|
||||
shape = (self.channels, self.image_size, self.image_size)
|
||||
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
||||
shape,cond,verbose=False,**kwargs)
|
||||
|
||||
else:
|
||||
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
||||
return_intermediates=True,**kwargs)
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
||||
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
||||
plot_diffusion_rows=True, **kwargs):
|
||||
# TODO: maybe add option for ddim sampling via DDIMSampler class
|
||||
|
||||
use_ddim = ddim_steps is not None
|
||||
|
||||
log = dict()
|
||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
|
@ -1288,7 +1300,9 @@ class LatentDiffusion(DDPM):
|
|||
if sample:
|
||||
# get denoise row
|
||||
with self.ema_scope("Plotting"):
|
||||
samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
||||
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,eta=ddim_eta)
|
||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
||||
x_samples = self.decode_first_stage(samples)
|
||||
log["samples"] = x_samples
|
||||
if plot_denoise_rows:
|
||||
|
@ -1299,8 +1313,11 @@ class LatentDiffusion(DDPM):
|
|||
self.first_stage_model, IdentityFirstStage):
|
||||
# also display when quantizing x0 while sampling
|
||||
with self.ema_scope("Plotting Quantized Denoised"):
|
||||
samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
||||
quantize_denoised=True)
|
||||
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,eta=ddim_eta,
|
||||
quantize_denoised=True)
|
||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
||||
# quantize_denoised=True)
|
||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||
log["samples_x0_quantized"] = x_samples
|
||||
|
||||
|
@ -1312,19 +1329,17 @@ class LatentDiffusion(DDPM):
|
|||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||
mask = mask[:, None, ...]
|
||||
with self.ema_scope("Plotting Inpaint"):
|
||||
samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
||||
quantize_denoised=False, x0=z[:N], mask=mask)
|
||||
|
||||
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
||||
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||
log["samples_inpainting"] = x_samples
|
||||
log["mask"] = mask
|
||||
if plot_denoise_rows:
|
||||
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
||||
log["denoise_row_inpainting"] = denoise_grid
|
||||
|
||||
# outpaint
|
||||
with self.ema_scope("Plotting Outpaint"):
|
||||
samples = self.sample(cond=c, batch_size=N, return_intermediates=False,
|
||||
quantize_denoised=False, x0=z[:N], mask=1. - mask)
|
||||
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
||||
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||
log["samples_outpainting"] = x_samples
|
||||
|
||||
|
|
|
@ -259,3 +259,9 @@ class HybridConditioner(nn.Module):
|
|||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
5
main.py
5
main.py
|
@ -676,7 +676,10 @@ if __name__ == "__main__":
|
|||
ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
|
||||
else:
|
||||
ngpu = 1
|
||||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
|
||||
if 'accumulate_grad_batches' in lightning_config.trainer:
|
||||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
||||
else:
|
||||
accumulate_grad_batches = 1
|
||||
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
||||
if opt.scale_lr:
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.0015
|
||||
linear_end: 0.0205
|
||||
log_every_t: 100
|
||||
timesteps: 1000
|
||||
loss_type: l1
|
||||
first_stage_key: image
|
||||
cond_stage_key: segmentation
|
||||
image_size: 64
|
||||
channels: 3
|
||||
concat_mode: true
|
||||
cond_stage_trainable: true
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 64
|
||||
in_channels: 6
|
||||
out_channels: 3
|
||||
model_channels: 128
|
||||
attention_resolutions:
|
||||
- 32
|
||||
- 16
|
||||
- 8
|
||||
num_res_blocks: 2
|
||||
channel_mult:
|
||||
- 1
|
||||
- 4
|
||||
- 8
|
||||
num_heads: 8
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.VQModelInterface
|
||||
params:
|
||||
embed_dim: 3
|
||||
n_embed: 8192
|
||||
ddconfig:
|
||||
double_z: false
|
||||
z_channels: 3
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.SpatialRescaler
|
||||
params:
|
||||
n_stages: 2
|
||||
in_channels: 182
|
||||
out_channels: 3
|
|
@ -4,10 +4,10 @@ wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/la
|
|||
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
|
||||
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
|
||||
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
|
||||
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1
|
||||
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
|
||||
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
|
||||
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
|
||||
wget -O models/first_stage_models/vq-f16/model.zip https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1
|
||||
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -6,9 +6,10 @@ wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/la
|
|||
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
|
||||
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
|
||||
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
|
||||
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
|
||||
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
|
||||
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
|
||||
wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
|
||||
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
|
||||
|
||||
|
||||
|
||||
|
@ -33,10 +34,16 @@ unzip -o model.zip
|
|||
cd ../semantic_synthesis512
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../semantic_synthesis256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../bsr_sr
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../layout2img-openimages256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../inpainting_big
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../..
|
||||
|
|
Loading…
Reference in New Issue