initial commit
commit 8f22429d74
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,54 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 16
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 16
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
@@ -0,0 +1,53 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
@@ -0,0 +1,54 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
@@ -0,0 +1,53 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 64
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 64
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16,8]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
@@ -0,0 +1,86 @@
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
@@ -0,0 +1,98 @@
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 32 for f8
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 4
        n_embed: 16384
        ckpt_path: configs/first_stage_models/vq-f8/model.yaml
        ddconfig:
          double_z: false
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 32
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        key: class_label
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 12
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
    validation:
      target: ldm.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
@@ -0,0 +1,68 @@
model:
  base_learning_rate: 0.0001
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 64
    channels: 3
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 192
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 5
        num_heads: 1
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512

    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        n_classes: 1001
        embed_dim: 512
        key: class_label
@@ -0,0 +1,85 @@
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 42
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
@@ -0,0 +1,85 @@
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: ldm.data.lsun.LSUNBedroomsTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNBedroomsValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
@@ -0,0 +1,91 @@
model:
  base_learning_rate: 5.0e-5   # set to target_lr by starting main.py with '--scale_lr False'
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    loss_type: l1
    first_stage_key: "image"
    cond_stage_key: "image"
    image_size: 32
    channels: 4
    cond_stage_trainable: False
    concat_mode: False
    scale_by_std: True
    monitor: 'val/loss_simple_ema'

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [ 1.]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 192
        attention_resolutions: [ 1, 2, 4, 8 ]   # 32, 16, 8, 4
        num_res_blocks: 2
        channel_mult: [ 1,2,2,4,4 ]  # 32, 16, 8, 4, 2
        num_heads: 8
        use_scale_shift_norm: True
        resblock_updown: True

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: "val/rec_loss"
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config: "__is_unconditional__"

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 96
    num_workers: 5
    wrap: False
    train:
      target: ldm.data.lsun.LSUNChurchesTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNChurchesValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False


  trainer:
    benchmark: True
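The scheduler_config above applies a per-step multiplier to base_learning_rate. As a rough illustration only (the exact semantics of ldm.lr_scheduler.LambdaLinearScheduler are assumed here, not shown in this commit), the multiplier ramps linearly from f_start to f_max over warm_up_steps and then interpolates linearly toward f_min across the configured cycle length:

import math

def lr_factor(step, warm_up_steps=10000, f_start=1e-6, f_max=1.0, f_min=1.0,
              cycle_length=10_000_000_000_000):
    # Warmup: linear ramp from f_start to f_max over warm_up_steps.
    if step < warm_up_steps:
        return f_start + (f_max - f_start) * step / warm_up_steps
    # Afterwards: linear interpolation from f_max toward f_min over the cycle.
    t = min((step - warm_up_steps) / cycle_length, 1.0)
    return f_max + (f_min - f_max) * t

# Effective learning rate at a step is roughly base_learning_rate * lr_factor(step);
# with f_max == f_min == 1.0, as in this config, the factor stays at 1.0 after warmup.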
@@ -0,0 +1,71 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32
@@ -0,0 +1,77 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: []

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32
@@ -0,0 +1,119 @@
model:
  base_learning_rate: 5.0e-3
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 256
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 256
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 6100
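Throughout these configs, each target: entry is a dotted import path and the adjacent params: block holds that object's constructor arguments. A minimal sketch of that convention follows (it mirrors the instantiate_from_config pattern of the CompVis latent-diffusion codebase rather than quoting code from this commit, and the config path is hypothetical):

import importlib
from omegaconf import OmegaConf

def build_from_config(node):
    # "target" names a class as "package.module.Class"; "params" are keyword arguments.
    module_name, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get("params", dict()))

cfg = OmegaConf.load("path/to/txt2img_finetune.yaml")  # hypothetical path
model = build_from_config(OmegaConf.to_container(cfg.model, resolve=True))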
@@ -0,0 +1,117 @@
model:
  base_learning_rate: 5.0e-3
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["painting"]
        per_image_tokens: false
        num_vectors_per_token: 1

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.personalized_style.PersonalizedBase
      params:
        size: 256
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized_style.PersonalizedBase
      params:
        size: 256
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
@@ -0,0 +1,110 @@
model:
  base_learning_rate: 5.0e-03
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 64
    channels: 4
    cond_stage_trainable: true   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0
    unfreeze_model: False
    model_lr: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 6100
@@ -0,0 +1,120 @@
model:
  base_learning_rate: 1.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    reg_weight: 1.0
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 64
    channels: 4
    cond_stage_trainable: true   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0
    unfreeze_model: True
    model_lr: 1.0e-5

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    reg:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        reg: true
        per_image_tokens: false
        repeats: 100

    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 4200
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -0,0 +1,31 @@
name: ldm
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.10
  - pip=20.3
  - cudatoolkit=11.3
  - pytorch=1.10.2
  - torchvision=0.11.3
  - numpy=1.22.3
  - pip:
    - albumentations==1.1.0
    - opencv-python==4.2.0.34
    - pudb==2019.2
    - imageio==2.14.1
    - imageio-ffmpeg==0.4.7
    - pytorch-lightning==1.5.9
    - omegaconf==2.1.1
    - test-tube>=0.7.5
    - streamlit>=0.73.1
    - setuptools==59.5.0
    - pillow==9.0.1
    - einops==0.4.1
    - torch-fidelity==0.3.0
    - transformers==4.18.0
    - torchmetrics==0.6.0
    - kornia==0.6
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e .
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,113 @@
import clip
import torch
from torchvision import transforms

from ldm.models.diffusion.ddim import DDIMSampler


class CLIPEvaluator(object):
    def __init__(self, device, clip_model='ViT-B/32') -> None:
        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)

        self.clip_preprocess = clip_preprocess

        self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1].
                                             clip_preprocess.transforms[:2] +   # to match CLIP input scale assumptions
                                             clip_preprocess.transforms[4:])    # + skip convert PIL to tensor

    def tokenize(self, strings: list):
        return clip.tokenize(strings).to(self.device)

    @torch.no_grad()
    def encode_text(self, tokens: list) -> torch.Tensor:
        return self.model.encode_text(tokens)

    @torch.no_grad()
    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:

        tokens = clip.tokenize(text).to(self.device)

        text_features = self.encode_text(tokens).detach()

        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def img_to_img_similarity(self, src_images, generated_images):
        src_img_features = self.get_image_features(src_images)
        gen_img_features = self.get_image_features(generated_images)

        return (src_img_features @ gen_img_features.T).mean()

    def txt_to_img_similarity(self, text, generated_images):
        text_features = self.get_text_features(text)
        gen_img_features = self.get_image_features(generated_images)

        return (text_features @ gen_img_features.T).mean()


class LDMCLIPEvaluator(CLIPEvaluator):
    def __init__(self, device, clip_model='ViT-B/32') -> None:
        super().__init__(device, clip_model)

    def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50):

        sampler = DDIMSampler(ldm_model)

        samples_per_batch = 8
        n_batches = n_samples // samples_per_batch

        # generate samples
        all_samples = list()
        with torch.no_grad():
            with ldm_model.ema_scope():
                uc = ldm_model.get_learned_conditioning(samples_per_batch * [""])

                for batch in range(n_batches):
                    c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text])
                    shape = [4, 256//8, 256//8]
                    samples_ddim, _ = sampler.sample(S=n_steps,
                                                     conditioning=c,
                                                     batch_size=samples_per_batch,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=5.0,
                                                     unconditional_conditioning=uc,
                                                     eta=0.0)

                    x_samples_ddim = ldm_model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)

                    all_samples.append(x_samples_ddim)

        all_samples = torch.cat(all_samples, axis=0)

        sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples)
        sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples)

        return sim_samples_to_img, sim_samples_to_text


class ImageDirEvaluator(CLIPEvaluator):
    def __init__(self, device, clip_model='ViT-B/32') -> None:
        super().__init__(device, clip_model)

    def evaluate(self, gen_samples, src_images, target_text):

        sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples)
        sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples)

        return sim_samples_to_img, sim_samples_to_text
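As a usage illustration only (the tensor shapes, device string, and prompt below are assumptions, not part of the file above), the directory-based evaluator compares CLIP embeddings of generated images against both the source images and the prompt text; inputs are expected in the generator's [-1, 1] range:

import torch

# Hypothetical batches of images in [-1, 1], shaped (N, 3, H, W).
evaluator = ImageDirEvaluator(device="cuda" if torch.cuda.is_available() else "cpu")
src_images = torch.rand(4, 3, 256, 256) * 2 - 1    # placeholder source batch
gen_samples = torch.rand(8, 3, 256, 256) * 2 - 1   # placeholder generated batch

img_sim, txt_sim = evaluator.evaluate(gen_samples, src_images, "a photo of *")
print(f"image-image CLIP similarity: {img_sim:.3f}, text-image: {txt_sim:.3f}")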
Binary file not shown.
After Width: | Height: | Size: 163 KiB |
Binary file not shown.
After Width: | Height: | Size: 192 KiB |
Binary file not shown.
After Width: | Height: | Size: 165 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,23 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
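For illustration only (the subclass and its field values below are hypothetical, not part of this commit), a concrete dataset just has to implement __iter__ and yield sample dicts:

import numpy as np

class ToyTxt2ImgIterable(Txt2ImgIterableBaseDataset):
    # Hypothetical subclass: yields dummy black images with a fixed caption.
    def __iter__(self):
        for _ in range(self.num_records):
            yield {
                "image": np.zeros((self.size, self.size, 3), dtype=np.float32),
                "caption": "a placeholder caption",
            }

# Usage sketch: ToyTxt2ImgIterable(num_records=10, size=64) iterates 10 samples
# and can be wrapped in a torch DataLoader like any other IterableDataset.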
@@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f)
    return dict((v,k) for k,v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)



class ImageNetSR(Dataset):
    def __init__(self, size=None,
|
||||||
|
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
|
||||||
|
random_crop=True):
|
||||||
|
"""
|
||||||
|
ImageNet Super-Resolution Dataloader
|
||||||
|
Performs the following ops in order:
|
||||||
|
1. crops a patch of size s from the image, either as a random crop or a center crop
|
||||||
|
2. resizes the crop to size using cv2 INTER_AREA interpolation
|
||||||
|
3. degrades resized crop with degradation_fn
|
||||||
|
|
||||||
|
:param size: side length to which the crop is resized after cropping
|
||||||
|
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
|
||||||
|
:param downscale_f: factor by which the resized crop is downsampled to create the low-resolution image
|
||||||
|
:param min_crop_f: determines crop size s,
|
||||||
|
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
|
||||||
|
:param max_crop_f: upper end of the interval from which the crop factor c is sampled (see min_crop_f)
|
||||||
|
:param data_root: (not an argument of this class; the image list is provided by get_base())
|
||||||
|
:param random_crop: if True use a random crop, otherwise a center crop
|
||||||
|
"""
|
||||||
|
self.base = self.get_base()
|
||||||
|
assert size
|
||||||
|
assert (size / downscale_f).is_integer()
|
||||||
|
self.size = size
|
||||||
|
self.LR_size = int(size / downscale_f)
|
||||||
|
self.min_crop_f = min_crop_f
|
||||||
|
self.max_crop_f = max_crop_f
|
||||||
|
assert(max_crop_f <= 1.)
|
||||||
|
self.center_crop = not random_crop
|
||||||
|
|
||||||
|
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
self.pil_interpolation = False # gets reset below if the chosen interpolation op comes from Pillow
|
||||||
|
|
||||||
|
if degradation == "bsrgan":
|
||||||
|
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
|
||||||
|
|
||||||
|
elif degradation == "bsrgan_light":
|
||||||
|
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
|
||||||
|
|
||||||
|
else:
|
||||||
|
interpolation_fn = {
|
||||||
|
"cv_nearest": cv2.INTER_NEAREST,
|
||||||
|
"cv_bilinear": cv2.INTER_LINEAR,
|
||||||
|
"cv_bicubic": cv2.INTER_CUBIC,
|
||||||
|
"cv_area": cv2.INTER_AREA,
|
||||||
|
"cv_lanczos": cv2.INTER_LANCZOS4,
|
||||||
|
"pil_nearest": PIL.Image.NEAREST,
|
||||||
|
"pil_bilinear": PIL.Image.BILINEAR,
|
||||||
|
"pil_bicubic": PIL.Image.BICUBIC,
|
||||||
|
"pil_box": PIL.Image.BOX,
|
||||||
|
"pil_hamming": PIL.Image.HAMMING,
|
||||||
|
"pil_lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[degradation]
|
||||||
|
|
||||||
|
self.pil_interpolation = degradation.startswith("pil_")
|
||||||
|
|
||||||
|
if self.pil_interpolation:
|
||||||
|
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
|
||||||
|
interpolation=interpolation_fn)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.base)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = self.base[i]
|
||||||
|
image = Image.open(example["file_path_"])
|
||||||
|
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
min_side_len = min(image.shape[:2])
|
||||||
|
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
|
||||||
|
crop_side_len = int(crop_side_len)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
|
||||||
|
|
||||||
|
image = self.cropper(image=image)["image"]
|
||||||
|
image = self.image_rescaler(image=image)["image"]
|
||||||
|
|
||||||
|
if self.pil_interpolation:
|
||||||
|
image_pil = PIL.Image.fromarray(image)
|
||||||
|
LR_image = self.degradation_process(image_pil)
|
||||||
|
LR_image = np.array(LR_image).astype(np.uint8)
|
||||||
|
|
||||||
|
else:
|
||||||
|
LR_image = self.degradation_process(image=image)["image"]
|
||||||
|
|
||||||
|
example["image"] = (image/127.5 - 1.0).astype(np.float32)
|
||||||
|
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetSRTrain(ImageNetSR):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def get_base(self):
|
||||||
|
with open("data/imagenet_train_hr_indices.p", "rb") as f:
|
||||||
|
indices = pickle.load(f)
|
||||||
|
dset = ImageNetTrain(process_images=False,)
|
||||||
|
return Subset(dset, indices)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetSRValidation(ImageNetSR):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def get_base(self):
|
||||||
|
with open("data/imagenet_val_hr_indices.p", "rb") as f:
|
||||||
|
indices = pickle.load(f)
|
||||||
|
dset = ImageNetValidation(process_images=False,)
|
||||||
|
return Subset(dset, indices)
|
|
@ -0,0 +1,92 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNBase(Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
txt_file,
|
||||||
|
data_root,
|
||||||
|
size=None,
|
||||||
|
interpolation="bicubic",
|
||||||
|
flip_p=0.5
|
||||||
|
):
|
||||||
|
self.data_paths = txt_file
|
||||||
|
self.data_root = data_root
|
||||||
|
with open(self.data_paths, "r") as f:
|
||||||
|
self.image_paths = f.read().splitlines()
|
||||||
|
self._length = len(self.image_paths)
|
||||||
|
self.labels = {
|
||||||
|
"relative_file_path_": [l for l in self.image_paths],
|
||||||
|
"file_path_": [os.path.join(self.data_root, l)
|
||||||
|
for l in self.image_paths],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||||
|
"bilinear": PIL.Image.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = dict((k, self.labels[k][i]) for k in self.labels)
|
||||||
|
image = Image.open(example["file_path_"])
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNChurchesTrain(LSUNBase):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNChurchesValidation(LSUNBase):
|
||||||
|
def __init__(self, flip_p=0., **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
|
||||||
|
flip_p=flip_p, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNBedroomsTrain(LSUNBase):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNBedroomsValidation(LSUNBase):
|
||||||
|
def __init__(self, flip_p=0.0, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
|
||||||
|
flip_p=flip_p, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNCatsTrain(LSUNBase):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNCatsValidation(LSUNBase):
|
||||||
|
def __init__(self, flip_p=0., **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
|
||||||
|
flip_p=flip_p, **kwargs)
|
|
@ -0,0 +1,220 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
training_templates_smallest = [
|
||||||
|
'photo of a sks {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
reg_templates_smallest = [
|
||||||
|
'photo of a {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
'a photo of a {}',
|
||||||
|
'a rendering of a {}',
|
||||||
|
'a cropped photo of the {}',
|
||||||
|
'the photo of a {}',
|
||||||
|
'a photo of a clean {}',
|
||||||
|
'a photo of a dirty {}',
|
||||||
|
'a dark photo of the {}',
|
||||||
|
'a photo of my {}',
|
||||||
|
'a photo of the cool {}',
|
||||||
|
'a close-up photo of a {}',
|
||||||
|
'a bright photo of the {}',
|
||||||
|
'a cropped photo of a {}',
|
||||||
|
'a photo of the {}',
|
||||||
|
'a good photo of the {}',
|
||||||
|
'a photo of one {}',
|
||||||
|
'a close-up photo of the {}',
|
||||||
|
'a rendition of the {}',
|
||||||
|
'a photo of the clean {}',
|
||||||
|
'a rendition of a {}',
|
||||||
|
'a photo of a nice {}',
|
||||||
|
'a good photo of a {}',
|
||||||
|
'a photo of the nice {}',
|
||||||
|
'a photo of the small {}',
|
||||||
|
'a photo of the weird {}',
|
||||||
|
'a photo of the large {}',
|
||||||
|
'a photo of a cool {}',
|
||||||
|
'a photo of a small {}',
|
||||||
|
'an illustration of a {}',
|
||||||
|
'a rendering of a {}',
|
||||||
|
'a cropped photo of the {}',
|
||||||
|
'the photo of a {}',
|
||||||
|
'an illustration of a clean {}',
|
||||||
|
'an illustration of a dirty {}',
|
||||||
|
'a dark photo of the {}',
|
||||||
|
'an illustration of my {}',
|
||||||
|
'an illustration of the cool {}',
|
||||||
|
'a close-up photo of a {}',
|
||||||
|
'a bright photo of the {}',
|
||||||
|
'a cropped photo of a {}',
|
||||||
|
'an illustration of the {}',
|
||||||
|
'a good photo of the {}',
|
||||||
|
'an illustration of one {}',
|
||||||
|
'a close-up photo of the {}',
|
||||||
|
'a rendition of the {}',
|
||||||
|
'an illustration of the clean {}',
|
||||||
|
'a rendition of a {}',
|
||||||
|
'an illustration of a nice {}',
|
||||||
|
'a good photo of a {}',
|
||||||
|
'an illustration of the nice {}',
|
||||||
|
'an illustration of the small {}',
|
||||||
|
'an illustration of the weird {}',
|
||||||
|
'an illustration of the large {}',
|
||||||
|
'an illustration of a cool {}',
|
||||||
|
'an illustration of a small {}',
|
||||||
|
'a depiction of a {}',
|
||||||
|
'a rendering of a {}',
|
||||||
|
'a cropped photo of the {}',
|
||||||
|
'the photo of a {}',
|
||||||
|
'a depiction of a clean {}',
|
||||||
|
'a depiction of a dirty {}',
|
||||||
|
'a dark photo of the {}',
|
||||||
|
'a depiction of my {}',
|
||||||
|
'a depiction of the cool {}',
|
||||||
|
'a close-up photo of a {}',
|
||||||
|
'a bright photo of the {}',
|
||||||
|
'a cropped photo of a {}',
|
||||||
|
'a depiction of the {}',
|
||||||
|
'a good photo of the {}',
|
||||||
|
'a depiction of one {}',
|
||||||
|
'a close-up photo of the {}',
|
||||||
|
'a rendition of the {}',
|
||||||
|
'a depiction of the clean {}',
|
||||||
|
'a rendition of a {}',
|
||||||
|
'a depiction of a nice {}',
|
||||||
|
'a good photo of a {}',
|
||||||
|
'a depiction of the nice {}',
|
||||||
|
'a depiction of the small {}',
|
||||||
|
'a depiction of the weird {}',
|
||||||
|
'a depiction of the large {}',
|
||||||
|
'a depiction of a cool {}',
|
||||||
|
'a depiction of a small {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_dual_templates_small = [
|
||||||
|
'a photo of a {} with {}',
|
||||||
|
'a rendering of a {} with {}',
|
||||||
|
'a cropped photo of the {} with {}',
|
||||||
|
'the photo of a {} with {}',
|
||||||
|
'a photo of a clean {} with {}',
|
||||||
|
'a photo of a dirty {} with {}',
|
||||||
|
'a dark photo of the {} with {}',
|
||||||
|
'a photo of my {} with {}',
|
||||||
|
'a photo of the cool {} with {}',
|
||||||
|
'a close-up photo of a {} with {}',
|
||||||
|
'a bright photo of the {} with {}',
|
||||||
|
'a cropped photo of a {} with {}',
|
||||||
|
'a photo of the {} with {}',
|
||||||
|
'a good photo of the {} with {}',
|
||||||
|
'a photo of one {} with {}',
|
||||||
|
'a close-up photo of the {} with {}',
|
||||||
|
'a rendition of the {} with {}',
|
||||||
|
'a photo of the clean {} with {}',
|
||||||
|
'a rendition of a {} with {}',
|
||||||
|
'a photo of a nice {} with {}',
|
||||||
|
'a good photo of a {} with {}',
|
||||||
|
'a photo of the nice {} with {}',
|
||||||
|
'a photo of the small {} with {}',
|
||||||
|
'a photo of the weird {} with {}',
|
||||||
|
'a photo of the large {} with {}',
|
||||||
|
'a photo of a cool {} with {}',
|
||||||
|
'a photo of a small {} with {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
per_img_token_list = [
|
||||||
|
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
|
||||||
|
]
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
data_root,
|
||||||
|
size=None,
|
||||||
|
repeats=100,
|
||||||
|
interpolation="bicubic",
|
||||||
|
flip_p=0.5,
|
||||||
|
set="train",
|
||||||
|
placeholder_token="dog",
|
||||||
|
per_image_tokens=False,
|
||||||
|
center_crop=False,
|
||||||
|
mixing_prob=0.25,
|
||||||
|
coarse_class_text=None,
|
||||||
|
reg = False
|
||||||
|
):
|
||||||
|
|
||||||
|
self.data_root = data_root
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||||
|
|
||||||
|
# self._length = len(self.image_paths)
|
||||||
|
self.num_images = len(self.image_paths)
|
||||||
|
self._length = self.num_images
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.per_image_tokens = per_image_tokens
|
||||||
|
self.center_crop = center_crop
|
||||||
|
self.mixing_prob = mixing_prob
|
||||||
|
|
||||||
|
self.coarse_class_text = coarse_class_text
|
||||||
|
|
||||||
|
if per_image_tokens:
|
||||||
|
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
|
||||||
|
|
||||||
|
if set == "train":
|
||||||
|
self._length = self.num_images * repeats
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||||
|
"bilinear": PIL.Image.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
self.reg = reg
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = {}
|
||||||
|
image = Image.open(self.image_paths[i % self.num_images])
|
||||||
|
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
placeholder_string = self.placeholder_token
|
||||||
|
if self.coarse_class_text:
|
||||||
|
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
|
||||||
|
|
||||||
|
if not self.reg:
|
||||||
|
text = random.choice(training_templates_smallest).format(placeholder_string)
|
||||||
|
else:
|
||||||
|
text = random.choice(reg_templates_smallest).format(placeholder_string)
|
||||||
|
|
||||||
|
example["caption"] = text
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example
|
|
@ -0,0 +1,129 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
'a painting in the style of {}',
|
||||||
|
'a rendering in the style of {}',
|
||||||
|
'a cropped painting in the style of {}',
|
||||||
|
'the painting in the style of {}',
|
||||||
|
'a clean painting in the style of {}',
|
||||||
|
'a dirty painting in the style of {}',
|
||||||
|
'a dark painting in the style of {}',
|
||||||
|
'a picture in the style of {}',
|
||||||
|
'a cool painting in the style of {}',
|
||||||
|
'a close-up painting in the style of {}',
|
||||||
|
'a bright painting in the style of {}',
|
||||||
|
'a cropped painting in the style of {}',
|
||||||
|
'a good painting in the style of {}',
|
||||||
|
'a close-up painting in the style of {}',
|
||||||
|
'a rendition in the style of {}',
|
||||||
|
'a nice painting in the style of {}',
|
||||||
|
'a small painting in the style of {}',
|
||||||
|
'a weird painting in the style of {}',
|
||||||
|
'a large painting in the style of {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_dual_templates_small = [
|
||||||
|
'a painting in the style of {} with {}',
|
||||||
|
'a rendering in the style of {} with {}',
|
||||||
|
'a cropped painting in the style of {} with {}',
|
||||||
|
'the painting in the style of {} with {}',
|
||||||
|
'a clean painting in the style of {} with {}',
|
||||||
|
'a dirty painting in the style of {} with {}',
|
||||||
|
'a dark painting in the style of {} with {}',
|
||||||
|
'a cool painting in the style of {} with {}',
|
||||||
|
'a close-up painting in the style of {} with {}',
|
||||||
|
'a bright painting in the style of {} with {}',
|
||||||
|
'a cropped painting in the style of {} with {}',
|
||||||
|
'a good painting in the style of {} with {}',
|
||||||
|
'a painting of one {} in the style of {}',
|
||||||
|
'a nice painting in the style of {} with {}',
|
||||||
|
'a small painting in the style of {} with {}',
|
||||||
|
'a weird painting in the style of {} with {}',
|
||||||
|
'a large painting in the style of {} with {}',
|
||||||
|
]
|
||||||
|
|
||||||
|
per_img_token_list = [
|
||||||
|
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
|
||||||
|
]
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
data_root,
|
||||||
|
size=None,
|
||||||
|
repeats=100,
|
||||||
|
interpolation="bicubic",
|
||||||
|
flip_p=0.5,
|
||||||
|
set="train",
|
||||||
|
placeholder_token="*",
|
||||||
|
per_image_tokens=False,
|
||||||
|
center_crop=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.data_root = data_root
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||||
|
|
||||||
|
# self._length = len(self.image_paths)
|
||||||
|
self.num_images = len(self.image_paths)
|
||||||
|
self._length = self.num_images
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.per_image_tokens = per_image_tokens
|
||||||
|
self.center_crop = center_crop
|
||||||
|
|
||||||
|
if per_image_tokens:
|
||||||
|
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
|
||||||
|
|
||||||
|
if set == "train":
|
||||||
|
self._length = self.num_images * repeats
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||||
|
"bilinear": PIL.Image.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = {}
|
||||||
|
image = Image.open(self.image_paths[i % self.num_images])
|
||||||
|
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
if self.per_image_tokens and np.random.uniform() < 0.25:
|
||||||
|
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
|
||||||
|
else:
|
||||||
|
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
|
||||||
|
|
||||||
|
example["caption"] = text
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example
|
|
@ -0,0 +1,98 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaWarmUpCosineScheduler:
|
||||||
|
"""
|
||||||
|
note: use with a base_lr of 1.0
|
||||||
|
"""
|
||||||
|
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
||||||
|
self.lr_warm_up_steps = warm_up_steps
|
||||||
|
self.lr_start = lr_start
|
||||||
|
self.lr_min = lr_min
|
||||||
|
self.lr_max = lr_max
|
||||||
|
self.lr_max_decay_steps = max_decay_steps
|
||||||
|
self.last_lr = 0.
|
||||||
|
self.verbosity_interval = verbosity_interval
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||||
|
if n < self.lr_warm_up_steps:
|
||||||
|
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
||||||
|
self.last_lr = lr
|
||||||
|
return lr
|
||||||
|
else:
|
||||||
|
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||||
|
t = min(t, 1.0)
|
||||||
|
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
||||||
|
1 + np.cos(t * np.pi))
|
||||||
|
self.last_lr = lr
|
||||||
|
return lr
|
||||||
|
|
||||||
|
def __call__(self, n, **kwargs):
|
||||||
|
return self.schedule(n,**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaWarmUpCosineScheduler2:
|
||||||
|
"""
|
||||||
|
supports repeated iterations, configurable via lists
|
||||||
|
note: use with a base_lr of 1.0.
|
||||||
|
"""
|
||||||
|
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
||||||
|
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
||||||
|
self.lr_warm_up_steps = warm_up_steps
|
||||||
|
self.f_start = f_start
|
||||||
|
self.f_min = f_min
|
||||||
|
self.f_max = f_max
|
||||||
|
self.cycle_lengths = cycle_lengths
|
||||||
|
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||||
|
self.last_f = 0.
|
||||||
|
self.verbosity_interval = verbosity_interval
|
||||||
|
|
||||||
|
def find_in_interval(self, n):
|
||||||
|
interval = 0
|
||||||
|
for cl in self.cum_cycles[1:]:
|
||||||
|
if n <= cl:
|
||||||
|
return interval
|
||||||
|
interval += 1
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
cycle = self.find_in_interval(n)
|
||||||
|
n = n - self.cum_cycles[cycle]
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||||
|
f"current cycle {cycle}")
|
||||||
|
if n < self.lr_warm_up_steps[cycle]:
|
||||||
|
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||||
|
t = min(t, 1.0)
|
||||||
|
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||||
|
1 + np.cos(t * np.pi))
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
|
||||||
|
def __call__(self, n, **kwargs):
|
||||||
|
return self.schedule(n, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
cycle = self.find_in_interval(n)
|
||||||
|
n = n - self.cum_cycles[cycle]
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||||
|
f"current cycle {cycle}")
|
||||||
|
|
||||||
|
if n < self.lr_warm_up_steps[cycle]:
|
||||||
|
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,443 @@
|
||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config

# Imports for names used further down in this file (np in get_input's batch resizing,
# LitEma for the EMA weights, version for the pytorch_lightning version check,
# LambdaLR for the optional schedulers); the LitEma location assumes the usual ldm
# layout (ldm/modules/ema.py).
import numpy as np
from packaging import version
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
|
||||||
|
|
||||||
|
|
||||||
|
class VQModel(pl.LightningModule):
|
||||||
|
def __init__(self,
|
||||||
|
ddconfig,
|
||||||
|
lossconfig,
|
||||||
|
n_embed,
|
||||||
|
embed_dim,
|
||||||
|
ckpt_path=None,
|
||||||
|
ignore_keys=[],
|
||||||
|
image_key="image",
|
||||||
|
colorize_nlabels=None,
|
||||||
|
monitor=None,
|
||||||
|
batch_resize_range=None,
|
||||||
|
scheduler_config=None,
|
||||||
|
lr_g_factor=1.0,
|
||||||
|
remap=None,
|
||||||
|
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||||
|
use_ema=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.n_embed = n_embed
|
||||||
|
self.image_key = image_key
|
||||||
|
self.encoder = Encoder(**ddconfig)
|
||||||
|
self.decoder = Decoder(**ddconfig)
|
||||||
|
self.loss = instantiate_from_config(lossconfig)
|
||||||
|
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||||
|
remap=remap,
|
||||||
|
sane_index_shape=sane_index_shape)
|
||||||
|
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||||
|
if colorize_nlabels is not None:
|
||||||
|
assert type(colorize_nlabels)==int
|
||||||
|
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||||
|
if monitor is not None:
|
||||||
|
self.monitor = monitor
|
||||||
|
self.batch_resize_range = batch_resize_range
|
||||||
|
if self.batch_resize_range is not None:
|
||||||
|
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||||
|
|
||||||
|
self.use_ema = use_ema
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema = LitEma(self)
|
||||||
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def ema_scope(self, context=None):
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema.store(self.parameters())
|
||||||
|
self.model_ema.copy_to(self)
|
||||||
|
if context is not None:
|
||||||
|
print(f"{context}: Switched to EMA weights")
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema.restore(self.parameters())
|
||||||
|
if context is not None:
|
||||||
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||||
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del sd[k]
|
||||||
|
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||||
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f"Missing Keys: {missing}")
|
||||||
|
print(f"Unexpected Keys: {unexpected}")
|
||||||
|
|
||||||
|
def on_train_batch_end(self, *args, **kwargs):
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema(self)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
return quant, emb_loss, info
|
||||||
|
|
||||||
|
def encode_to_prequant(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def decode(self, quant):
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def decode_code(self, code_b):
|
||||||
|
quant_b = self.quantize.embed_code(code_b)
|
||||||
|
dec = self.decode(quant_b)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, return_pred_indices=False):
|
||||||
|
quant, diff, (_,_,ind) = self.encode(input)
|
||||||
|
dec = self.decode(quant)
|
||||||
|
if return_pred_indices:
|
||||||
|
return dec, diff, ind
|
||||||
|
return dec, diff
|
||||||
|
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||||
|
if self.batch_resize_range is not None:
|
||||||
|
lower_size = self.batch_resize_range[0]
|
||||||
|
upper_size = self.batch_resize_range[1]
|
||||||
|
if self.global_step <= 4:
|
||||||
|
# do the first few batches with max size to avoid later oom
|
||||||
|
new_resize = upper_size
|
||||||
|
else:
|
||||||
|
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||||
|
if new_resize != x.shape[2]:
|
||||||
|
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||||
|
x = x.detach()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
# https://github.com/pytorch/pytorch/issues/37142
|
||||||
|
# try not to fool the heuristics
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
# autoencode
|
||||||
|
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train",
|
||||||
|
predicted_indices=ind)
|
||||||
|
|
||||||
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||||
|
return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
log_dict = self._validation_step(batch, batch_idx)
|
||||||
|
with self.ema_scope():
|
||||||
|
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||||
|
return log_dict
|
||||||
|
|
||||||
|
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||||
|
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split="val"+suffix,
|
||||||
|
predicted_indices=ind
|
||||||
|
)
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split="val"+suffix,
|
||||||
|
predicted_indices=ind
|
||||||
|
)
|
||||||
|
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||||
|
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||||
|
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||||
|
self.log(f"val{suffix}/aeloss", aeloss,
|
||||||
|
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||||
|
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||||
|
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr_d = self.learning_rate
|
||||||
|
lr_g = self.lr_g_factor*self.learning_rate
|
||||||
|
print("lr_d", lr_d)
|
||||||
|
print("lr_g", lr_g)
|
||||||
|
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||||
|
list(self.decoder.parameters())+
|
||||||
|
list(self.quantize.parameters())+
|
||||||
|
list(self.quant_conv.parameters())+
|
||||||
|
list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr_g, betas=(0.5, 0.9))
|
||||||
|
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||||
|
lr=lr_d, betas=(0.5, 0.9))
|
||||||
|
|
||||||
|
if self.scheduler_config is not None:
|
||||||
|
scheduler = instantiate_from_config(self.scheduler_config)
|
||||||
|
|
||||||
|
print("Setting up LambdaLR scheduler...")
|
||||||
|
scheduler = [
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return [opt_ae, opt_disc], scheduler
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if only_inputs:
|
||||||
|
log["inputs"] = x
|
||||||
|
return log
|
||||||
|
xrec, _ = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log["inputs"] = x
|
||||||
|
log["reconstructions"] = xrec
|
||||||
|
if plot_ema:
|
||||||
|
with self.ema_scope():
|
||||||
|
xrec_ema, _ = self(x)
|
||||||
|
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||||
|
log["reconstructions_ema"] = xrec_ema
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == "segmentation"
|
||||||
|
if not hasattr(self, "colorize"):
|
||||||
|
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VQModelInterface(VQModel):
|
||||||
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
|
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def decode(self, h, force_not_quantize=False):
|
||||||
|
# also go through quantization layer
|
||||||
|
if not force_not_quantize:
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
else:
|
||||||
|
quant = h
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKL(pl.LightningModule):
|
||||||
|
def __init__(self,
|
||||||
|
ddconfig,
|
||||||
|
lossconfig,
|
||||||
|
embed_dim,
|
||||||
|
ckpt_path=None,
|
||||||
|
ignore_keys=[],
|
||||||
|
image_key="image",
|
||||||
|
colorize_nlabels=None,
|
||||||
|
monitor=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.image_key = image_key
|
||||||
|
self.encoder = Encoder(**ddconfig)
|
||||||
|
self.decoder = Decoder(**ddconfig)
|
||||||
|
self.loss = instantiate_from_config(lossconfig)
|
||||||
|
assert ddconfig["double_z"]
|
||||||
|
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
if colorize_nlabels is not None:
|
||||||
|
assert type(colorize_nlabels)==int
|
||||||
|
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||||
|
if monitor is not None:
|
||||||
|
self.monitor = monitor
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||||
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del sd[k]
|
||||||
|
self.load_state_dict(sd, strict=False)
|
||||||
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
moments = self.quant_conv(h)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
|
return posterior
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
dec = self.decoder(z)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, sample_posterior=True):
|
||||||
|
posterior = self.encode(input)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample()
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
dec = self.decode(z)
|
||||||
|
return dec, posterior
|
||||||
|
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
# train encoder+decoder+logvar
|
||||||
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||||
|
return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# train the discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
|
||||||
|
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="val")
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="val")
|
||||||
|
|
||||||
|
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr = self.learning_rate
|
||||||
|
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||||
|
list(self.decoder.parameters())+
|
||||||
|
list(self.quant_conv.parameters())+
|
||||||
|
list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr, betas=(0.5, 0.9))
|
||||||
|
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||||
|
lr=lr, betas=(0.5, 0.9))
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if not only_inputs:
|
||||||
|
xrec, posterior = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||||
|
log["reconstructions"] = xrec
|
||||||
|
log["inputs"] = x
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == "segmentation"
|
||||||
|
if not hasattr(self, "colorize"):
|
||||||
|
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFirstStage(torch.nn.Module):
|
||||||
|
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||||
|
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def encode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def quantize(self, x, *args, **kwargs):
|
||||||
|
if self.vq_interface:
|
||||||
|
return x, None, [None, None, None]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return x
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,267 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
from copy import deepcopy
|
||||||
|
from einops import rearrange
|
||||||
|
from glob import glob
|
||||||
|
from natsort import natsorted
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
||||||
|
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||||
|
|
||||||
|
__models__ = {
|
||||||
|
'class_label': EncoderUNetModel,
|
||||||
|
'segmentation': UNetModel
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def disabled_train(self, mode=True):
|
||||||
|
"""Overwrite model.train with this function to make sure train/eval mode
|
||||||
|
does not change anymore."""
|
||||||
|
return self
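# Hedged usage note (mirrors load_diffusion below; cfg is a placeholder config):
# disabled_train is meant to be monkey-patched onto a frozen sub-model so that
# subsequent train/eval mode switches performed by parent modules become no-ops, e.g.
#   frozen = instantiate_from_config(cfg).eval()
#   frozen.train = disabled_train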
|
||||||
|
|
||||||
|
|
||||||
|
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
diffusion_path,
|
||||||
|
num_classes,
|
||||||
|
ckpt_path=None,
|
||||||
|
pool='attention',
|
||||||
|
label_key=None,
|
||||||
|
diffusion_ckpt_path=None,
|
||||||
|
scheduler_config=None,
|
||||||
|
weight_decay=1.e-2,
|
||||||
|
log_steps=10,
|
||||||
|
monitor='val/loss',
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_classes = num_classes
|
||||||
|
# get latest config of diffusion model
|
||||||
|
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
||||||
|
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||||
|
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||||
|
self.load_diffusion()
|
||||||
|
|
||||||
|
self.monitor = monitor
|
||||||
|
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||||
|
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
||||||
|
self.log_steps = log_steps
|
||||||
|
|
||||||
|
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
||||||
|
else self.diffusion_model.cond_stage_key
|
||||||
|
|
||||||
|
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
||||||
|
|
||||||
|
if self.label_key not in __models__:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self.load_classifier(ckpt_path, pool)
|
||||||
|
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.use_scheduler = self.scheduler_config is not None
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||||
|
sd = torch.load(path, map_location="cpu")
|
||||||
|
if "state_dict" in list(sd.keys()):
|
||||||
|
sd = sd["state_dict"]
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del sd[k]
|
||||||
|
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||||
|
sd, strict=False)
|
||||||
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f"Missing Keys: {missing}")
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
print(f"Unexpected Keys: {unexpected}")
|
||||||
|
|
||||||
|
def load_diffusion(self):
|
||||||
|
model = instantiate_from_config(self.diffusion_config)
|
||||||
|
self.diffusion_model = model.eval()
|
||||||
|
self.diffusion_model.train = disabled_train
|
||||||
|
for param in self.diffusion_model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def load_classifier(self, ckpt_path, pool):
|
||||||
|
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
||||||
|
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
||||||
|
model_config.out_channels = self.num_classes
|
||||||
|
if self.label_key == 'class_label':
|
||||||
|
model_config.pool = pool
|
||||||
|
|
||||||
|
self.model = __models__[self.label_key](**model_config)
|
||||||
|
if ckpt_path is not None:
|
||||||
|
print('#####################################################################')
|
||||||
|
print(f'load from ckpt "{ckpt_path}"')
|
||||||
|
print('#####################################################################')
|
||||||
|
self.init_from_ckpt(ckpt_path)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_x_noisy(self, x, t, noise=None):
|
||||||
|
noise = default(noise, lambda: torch.randn_like(x))
|
||||||
|
continuous_sqrt_alpha_cumprod = None
|
||||||
|
if self.diffusion_model.use_continuous_noise:
|
||||||
|
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
||||||
|
# todo: make sure t+1 is correct here
|
||||||
|
|
||||||
|
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
||||||
|
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
||||||
|
|
||||||
|
def forward(self, x_noisy, t, *args, **kwargs):
|
||||||
|
return self.model(x_noisy, t)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = rearrange(x, 'b h w c -> b c h w')
|
||||||
|
x = x.to(memory_format=torch.contiguous_format).float()
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_conditioning(self, batch, k=None):
|
||||||
|
if k is None:
|
||||||
|
k = self.label_key
|
||||||
|
assert k is not None, 'Needs to provide label key'
|
||||||
|
|
||||||
|
targets = batch[k].to(self.device)
|
||||||
|
|
||||||
|
if self.label_key == 'segmentation':
|
||||||
|
targets = rearrange(targets, 'b h w c -> b c h w')
|
||||||
|
for down in range(self.numd):
|
||||||
|
h, w = targets.shape[-2:]
|
||||||
|
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
||||||
|
|
||||||
|
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||||
|
|
||||||
|
return targets
|
||||||
|
|
||||||
|
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
||||||
|
_, top_ks = torch.topk(logits, k, dim=1)
|
||||||
|
if reduction == "mean":
|
||||||
|
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||||
|
elif reduction == "none":
|
||||||
|
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||||
|
|
||||||
|
def on_train_epoch_start(self):
|
||||||
|
# save some memory
|
||||||
|
self.diffusion_model.model.to('cpu')
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def write_logs(self, loss, logits, targets):
|
||||||
|
log_prefix = 'train' if self.training else 'val'
|
||||||
|
log = {}
|
||||||
|
log[f"{log_prefix}/loss"] = loss.mean()
|
||||||
|
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
||||||
|
logits, targets, k=1, reduction="mean"
|
||||||
|
)
|
||||||
|
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
||||||
|
logits, targets, k=5, reduction="mean"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
||||||
|
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
||||||
|
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
||||||
|
lr = self.optimizers().param_groups[0]['lr']
|
||||||
|
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
||||||
|
|
||||||
|
def shared_step(self, batch, t=None):
|
||||||
|
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
||||||
|
targets = self.get_conditioning(batch)
|
||||||
|
if targets.dim() == 4:
|
||||||
|
targets = targets.argmax(dim=1)
|
||||||
|
if t is None:
|
||||||
|
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
||||||
|
else:
|
||||||
|
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
||||||
|
x_noisy = self.get_x_noisy(x, t)
|
||||||
|
logits = self(x_noisy, t)
|
||||||
|
|
||||||
|
loss = F.cross_entropy(logits, targets, reduction='none')
|
||||||
|
|
||||||
|
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
||||||
|
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss, logits, x_noisy, targets
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
loss, *_ = self.shared_step(batch)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def reset_noise_accs(self):
|
||||||
|
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
||||||
|
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
||||||
|
|
||||||
|
def on_validation_start(self):
|
||||||
|
self.reset_noise_accs()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
loss, *_ = self.shared_step(batch)
|
||||||
|
|
||||||
|
for t in self.noisy_acc:
|
||||||
|
_, logits, _, targets = self.shared_step(batch, t)
|
||||||
|
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
||||||
|
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
||||||
|
|
||||||
|
if self.use_scheduler:
|
||||||
|
scheduler = instantiate_from_config(self.scheduler_config)
|
||||||
|
|
||||||
|
print("Setting up LambdaLR scheduler...")
|
||||||
|
scheduler = [
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1
|
||||||
|
}]
|
||||||
|
return [optimizer], scheduler
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
||||||
|
log['inputs'] = x
|
||||||
|
|
||||||
|
y = self.get_conditioning(batch)
|
||||||
|
|
||||||
|
if self.label_key == 'class_label':
|
||||||
|
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
||||||
|
log['labels'] = y
|
||||||
|
|
||||||
|
if ismap(y):
|
||||||
|
log['labels'] = self.diffusion_model.to_rgb(y)
|
||||||
|
|
||||||
|
for step in range(self.log_steps):
|
||||||
|
current_time = step * self.log_time_interval
|
||||||
|
|
||||||
|
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
||||||
|
|
||||||
|
log[f'inputs@t{current_time}'] = x_noisy
|
||||||
|
|
||||||
|
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
||||||
|
pred = rearrange(pred, 'b h w c -> b c h w')
|
||||||
|
|
||||||
|
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
||||||
|
|
||||||
|
for key in log:
|
||||||
|
log[key] = log[key][:N]
|
||||||
|
|
||||||
|
return log
|
|
@ -0,0 +1,241 @@
|
||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
||||||
|
extract_into_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device("cuda"):
|
||||||
|
attr = attr.to(torch.device("cuda"))
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||||
|
|
||||||
|
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
|
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
|
img, pred_x0 = outs
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
return img, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
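# DDIM update below: x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise;
# with eta=0 all sigmas are zero and the step is deterministic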
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||||
|
# fast, but does not allow for exact reconstruction
|
||||||
|
# t serves as an index to gather the correct alphas
|
||||||
|
if use_original_steps:
|
||||||
|
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||||
|
else:
|
||||||
|
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||||
|
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||||
|
use_original_steps=False):
|
||||||
|
|
||||||
|
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||||
|
timesteps = timesteps[:t_start]
|
||||||
|
|
||||||
|
time_range = np.flip(timesteps)
|
||||||
|
total_steps = timesteps.shape[0]
|
||||||
|
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||||
|
x_dec = x_latent
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||||
|
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
|
return x_dec
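# --- Illustrative usage sketch (not part of the original file) ---------------
# A minimal, hypothetical example of driving DDIMSampler. It assumes a trained
# LatentDiffusion-style `model` with a 4-channel 64x64 latent space, a method
# `decode_first_stage`, and a conditioning tensor `c` prepared elsewhere.
def _ddim_usage_sketch(model, c, n_samples=4, steps=50):
    sampler = DDIMSampler(model)
    # 50 DDIM steps with eta=0.0 gives a deterministic sampling trajectory
    samples, _ = sampler.sample(S=steps, batch_size=n_samples, shape=(4, 64, 64),
                                conditioning=c, eta=0.0, verbose=False)
    # map the sampled latents back to image space
    return model.decode_first_stage(samples)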
|
File diff suppressed because it is too large
|
@ -0,0 +1,236 @@
|
||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||||
|
|
||||||
|
|
||||||
|
class PLMSSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device("cuda"):
|
||||||
|
attr = attr.to(torch.device("cuda"))
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
if ddim_eta != 0:
|
||||||
|
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
|
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
old_eps=old_eps, t_next=ts_next)
|
||||||
|
img, pred_x0, e_t = outs
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
return img, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
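# old_eps keeps at most the three previous eps predictions, so the sampler warms up
# from a Heun-style 2nd order step to the 4th order Adams-Bashforth formula above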
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
|
|
@ -0,0 +1,261 @@
|
||||||
|
from inspect import isfunction
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn, einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def uniq(arr):
|
||||||
|
return{el: True for el in arr}.keys()
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
if exists(val):
|
||||||
|
return val
|
||||||
|
return d() if isfunction(d) else d
|
||||||
|
|
||||||
|
|
||||||
|
def max_neg_value(t):
|
||||||
|
return -torch.finfo(t.dtype).max
|
||||||
|
|
||||||
|
|
||||||
|
def init_(tensor):
|
||||||
|
dim = tensor.shape[-1]
|
||||||
|
std = 1 / math.sqrt(dim)
|
||||||
|
tensor.uniform_(-std, std)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
# feedforward
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
|
return x * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = default(dim_out, dim)
|
||||||
|
project_in = nn.Sequential(
|
||||||
|
nn.Linear(dim, inner_dim),
|
||||||
|
nn.GELU()
|
||||||
|
) if not glu else GEGLU(dim, inner_dim)
|
||||||
|
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
project_in,
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(inner_dim, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
def zero_module(module):
|
||||||
|
"""
|
||||||
|
Zero out the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().zero_()
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels):
|
||||||
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttention(nn.Module):
|
||||||
|
def __init__(self, dim, heads=4, dim_head=32):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
hidden_dim = dim_head * heads
|
||||||
|
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||||
|
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
qkv = self.to_qkv(x)
|
||||||
|
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||||
|
k = k.softmax(dim=-1)
|
||||||
|
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||||
|
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||||
|
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialSelfAttention(nn.Module):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = Normalize(in_channels)
|
||||||
|
self.q = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.k = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.v = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b,c,h,w = q.shape
|
||||||
|
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||||
|
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||||
|
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||||
|
|
||||||
|
w_ = w_ * (int(c)**(-0.5))
|
||||||
|
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||||
|
w_ = rearrange(w_, 'b i j -> b j i')
|
||||||
|
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||||
|
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||||
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
return x+h_
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
context_dim = default(context_dim, query_dim)
|
||||||
|
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
self.heads = heads
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||||
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, query_dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
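# heads are folded into the batch dimension, so a single einsum below computes
# scaled dot-product attention for all heads at once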
|
||||||
|
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
||||||
|
super().__init__()
|
||||||
|
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
||||||
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||||
|
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||||
|
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
|
self.checkpoint = checkpoint
|
||||||
|
|
||||||
|
def forward(self, x, context=None):
|
||||||
|
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
|
def _forward(self, x, context=None):
|
||||||
|
x = self.attn1(self.norm1(x)) + x
|
||||||
|
x = self.attn2(self.norm2(x), context=context) + x
|
||||||
|
x = self.ff(self.norm3(x)) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialTransformer(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer block for image-like data.
|
||||||
|
First, project the input (aka embedding)
|
||||||
|
and reshape to b, t, d.
|
||||||
|
Then apply standard transformer action.
|
||||||
|
Finally, reshape to image
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, n_heads, d_head,
|
||||||
|
depth=1, dropout=0., context_dim=None):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
inner_dim = n_heads * d_head
|
||||||
|
self.norm = Normalize(in_channels)
|
||||||
|
|
||||||
|
self.proj_in = nn.Conv2d(in_channels,
|
||||||
|
inner_dim,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||||
|
for d in range(depth)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0))
|
||||||
|
|
||||||
|
def forward(self, x, context=None):
|
||||||
|
# note: if no context is given, cross-attention defaults to self-attention
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
x_in = x
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.proj_in(x)
|
||||||
|
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, context=context)
|
||||||
|
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||||
|
x = self.proj_out(x)
|
||||||
|
return x + x_in
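# --- Illustrative usage sketch (not part of the original file) ---------------
# Shows the expected tensor shapes: a SpatialTransformer whose inner dimension
# (n_heads * d_head = 320) equals in_channels, cross-attending to a hypothetical
# 77-token, 768-dim context (e.g. text-encoder states).
def _spatial_transformer_sketch():
    st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
    x = torch.randn(2, 320, 32, 32)   # image-like feature map
    ctx = torch.randn(2, 77, 768)     # conditioning tokens
    out = st(x, context=ctx)          # residual output, same shape as x
    return out.shape                  # torch.Size([2, 320, 32, 32])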
|
|
@ -0,0 +1,835 @@
|
||||||
|
# pytorch_diffusion + derived encoder decoder
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.modules.attention import LinearAttention
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestep_embedding(timesteps, embedding_dim):
|
||||||
|
"""
|
||||||
|
Build sinusoidal timestep embeddings (ported from Fairseq).
This matches the implementation in Denoising Diffusion Probabilistic Models
and in tensor2tensor, but differs slightly from the description in
Section 3.5 of "Attention Is All You Need".
|
||||||
|
"""
|
||||||
|
assert len(timesteps.shape) == 1
|
||||||
|
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||||
|
emb = emb.to(device=timesteps.device)
|
||||||
|
emb = timesteps.float()[:, None] * emb[None, :]
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||||
|
if embedding_dim % 2 == 1: # zero pad
|
||||||
|
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||||
|
return emb
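# i.e. emb[t] = [sin(t * w_0), ..., sin(t * w_{d/2-1}), cos(t * w_0), ..., cos(t * w_{d/2-1})]
# with frequencies w_i = exp(-i * log(10000) / (d/2 - 1))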
|
||||||
|
|
||||||
|
|
||||||
|
def nonlinearity(x):
|
||||||
|
# swish
|
||||||
|
return x*torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.with_conv:
|
||||||
|
pad = (0,1,0,1)
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
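# F.pad order is (left, right, top, bottom): padding only right and bottom lets the
# stride-2, padding-0 conv halve H and W exactly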
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||||
|
dropout, temb_channels=512):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
|
self.norm1 = Normalize(in_channels)
|
||||||
|
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
if temb_channels > 0:
|
||||||
|
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||||
|
out_channels)
|
||||||
|
self.norm2 = Normalize(out_channels)
|
||||||
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
|
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, x, temb):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
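# timestep conditioning: the activated time embedding is projected and added
# as a per-channel bias to the feature map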
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.dropout(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x+h
|
||||||
|
|
||||||
|
|
||||||
|
class LinAttnBlock(LinearAttention):
|
||||||
|
"""to match AttnBlock usage"""
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = Normalize(in_channels)
|
||||||
|
self.q = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.k = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.v = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b,c,h,w = q.shape
|
||||||
|
q = q.reshape(b,c,h*w)
|
||||||
|
q = q.permute(0,2,1) # b,hw,c
|
||||||
|
k = k.reshape(b,c,h*w) # b,c,hw
|
||||||
|
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||||
|
w_ = w_ * (int(c)**(-0.5))
|
||||||
|
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v = v.reshape(b,c,h*w)
|
||||||
|
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||||
|
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||||
|
h_ = h_.reshape(b,c,h,w)
|
||||||
|
|
||||||
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
return x+h_
|
||||||
|
|
||||||
|
|
||||||
|
def make_attn(in_channels, attn_type="vanilla"):
|
||||||
|
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||||
|
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||||
|
if attn_type == "vanilla":
|
||||||
|
return AttnBlock(in_channels)
|
||||||
|
elif attn_type == "none":
|
||||||
|
return nn.Identity(in_channels)
|
||||||
|
else:
|
||||||
|
return LinAttnBlock(in_channels)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||||
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
|
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||||
|
super().__init__()
|
||||||
|
if use_linear_attn: attn_type = "linear"
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = self.ch*4
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.use_timestep = use_timestep
|
||||||
|
if self.use_timestep:
|
||||||
|
# timestep embedding
|
||||||
|
self.temb = nn.Module()
|
||||||
|
self.temb.dense = nn.ModuleList([
|
||||||
|
torch.nn.Linear(self.ch,
|
||||||
|
self.temb_ch),
|
||||||
|
torch.nn.Linear(self.temb_ch,
|
||||||
|
self.temb_ch),
|
||||||
|
])
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||||
|
self.ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,)+tuple(ch_mult)
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch*in_ch_mult[i_level]
|
||||||
|
block_out = ch*ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout))
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions-1:
|
||||||
|
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout)
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch*ch_mult[i_level]
|
||||||
|
skip_in = ch*ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks+1):
|
||||||
|
if i_block == self.num_res_blocks:
|
||||||
|
skip_in = ch*in_ch_mult[i_level]
|
||||||
|
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout))
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in)
|
||||||
|
self.conv_out = torch.nn.Conv2d(block_in,
|
||||||
|
out_ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
def forward(self, x, t=None, context=None):
|
||||||
|
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||||
|
if context is not None:
|
||||||
|
# assume aligned context, cat along channel axis
|
||||||
|
x = torch.cat((x, context), dim=1)
|
||||||
|
if self.use_timestep:
|
||||||
|
# timestep embedding
|
||||||
|
assert t is not None
|
||||||
|
temb = get_timestep_embedding(t, self.ch)
|
||||||
|
temb = self.temb.dense[0](temb)
|
||||||
|
temb = nonlinearity(temb)
|
||||||
|
temb = self.temb.dense[1](temb)
|
||||||
|
else:
|
||||||
|
temb = None
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
hs.append(h)
|
||||||
|
if i_level != self.num_resolutions-1:
|
||||||
|
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid.block_1(h, temb)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks+1):
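# UNet skip connection: the matching encoder feature (hs.pop()) is concatenated
# channel-wise before each decoder ResnetBlock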
|
||||||
|
h = self.up[i_level].block[i_block](
|
||||||
|
torch.cat([h, hs.pop()], dim=1), temb)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.conv_out.weight
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||||
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
|
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||||
|
**ignore_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
if use_linear_attn: attn_type = "linear"
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||||
|
self.ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,)+tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch*in_ch_mult[i_level]
|
||||||
|
block_out = ch*ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout))
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions-1:
|
||||||
|
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout)
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in)
|
||||||
|
self.conv_out = torch.nn.Conv2d(block_in,
|
||||||
|
2*z_channels if double_z else z_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# timestep embedding
|
||||||
|
temb = None
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
hs.append(h)
|
||||||
|
if i_level != self.num_resolutions-1:
|
||||||
|
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid.block_1(h, temb)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
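# --- Illustrative shape check (not part of the original file) ----------------
# With three resolution levels the encoder downsamples twice (64 -> 32 -> 16);
# double_z defaults to True, so conv_out emits 2 * z_channels channels
# (typically interpreted as a mean/log-variance pair by a KL autoencoder).
def _encoder_shape_sketch():
    enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[16], in_channels=3, resolution=64, z_channels=4)
    h = enc(torch.randn(1, 3, 64, 64))
    return h.shape  # torch.Size([1, 8, 16, 16])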
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||||
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
|
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||||
|
attn_type="vanilla", **ignorekwargs):
|
||||||
|
super().__init__()
|
||||||
|
if use_linear_attn: attn_type = "linear"
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.give_pre_end = give_pre_end
|
||||||
|
self.tanh_out = tanh_out
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
in_ch_mult = (1,)+tuple(ch_mult)
|
||||||
|
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||||
|
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||||
|
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||||
|
print("Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)))
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = torch.nn.Conv2d(z_channels,
|
||||||
|
block_in,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout)
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch*ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks+1):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout))
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in)
|
||||||
|
self.conv_out = torch.nn.Conv2d(block_in,
|
||||||
|
out_ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
def forward(self, z):
|
||||||
|
#assert z.shape[1:] == self.z_shape[1:]
|
||||||
|
self.last_z_shape = z.shape
|
||||||
|
|
||||||
|
# timestep embedding
|
||||||
|
temb = None
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid.block_1(h, temb)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks+1):
|
||||||
|
h = self.up[i_level].block[i_block](h, temb)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
if self.give_pre_end:
|
||||||
|
return h
|
||||||
|
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
if self.tanh_out:
|
||||||
|
h = torch.tanh(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                    ResnetBlock(in_channels=in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    nn.Conv2d(2 * in_channels, in_channels, 1),
                                    Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                             temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])

        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2] * self.factor)),
                                                     int(round(x.shape[3] * self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x


class MergedRescaleEncoder(nn.Module):
    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1, 2, 4, 8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch,
                               ch_mult=ch_mult, z_channels=intermediate_chn, double_z=False,
                               resolution=resolution, attn_resolutions=attn_resolutions,
                               dropout=dropout, resamp_with_conv=resamp_with_conv, out_ch=None)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
                                       mid_channels=intermediate_chn, out_channels=out_ch,
                                       depth=rescale_module_depth)

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch,
                 ch_mult=(1, 2, 4, 8), dropout=0.0, resamp_with_conv=True,
                 rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions,
                               dropout=dropout, resamp_with_conv=resamp_with_conv, in_channels=None,
                               num_res_blocks=num_res_blocks, ch_mult=ch_mult,
                               resolution=resolution, ch=ch)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels,
                                       mid_channels=tmp_chn, out_channels=tmp_chn,
                                       depth=rescale_module_depth)

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1. + (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels,
                                       mid_channels=2 * in_channels, out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels,
                               num_res_blocks=2, attn_resolutions=[], in_channels=None,
                               ch=in_channels, ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x


class FirstStagePostProcessor(nn.Module):

    def __init__(self, ch_mult: list, in_channels,
                 pretrained_model: nn.Module = None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        c = self.pretrained_model.encode(x)
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self, x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        for submodel, downmodel in zip(self.model, self.downsampler):
            z = submodel(z, temb=None)
            z = downmodel(z)

        if self.do_reshape:
            z = rearrange(z, 'b c h w -> b (h w) c')
        return z
@ -0,0 +1,961 @@
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from ldm.modules.diffusionmodules.util import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer


# dummy replace
def convert_module_to_f16(x):
    pass

def convert_module_to_f32(x):
    pass


## go
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class TransposedUpsample(nn.Module):
    'Learned 2x upsampling without padding'
    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


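# Shape sketch (illustrative only; the channel and spatial sizes are assumed for the
# example): both resampling layers keep the channel count and only change spatial size.
def _resample_shape_sketch():
    x = th.randn(1, 64, 32, 32)
    up = Upsample(64, use_conv=True)        # (1, 64, 32, 32) -> (1, 64, 64, 64)
    down = Downsample(64, use_conv=True)    # (1, 64, 32, 32) -> (1, 64, 16, 16)
    return up(x).shape, down(x).shape

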
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


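# Usage sketch (illustrative only; the channel counts are assumed for the example): a
# ResBlock consumes a feature map plus an [N x emb_channels] timestep embedding, and can
# change the channel count via out_channels.
def _resblock_usage_sketch():
    block = ResBlock(channels=64, emb_channels=256, dropout=0.0, out_channels=128)
    x = th.randn(2, 64, 32, 32)
    emb = th.randn(2, 256)
    return block(x, emb).shape  # expected: (2, 128, 32, 32)

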
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        #return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


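# Shape sketch (illustrative only; the head count and sizes are assumed): a packed qkv
# tensor of shape [N x (3 * H * C) x T] attends over the T positions and comes back as
# [N x (H * C) x T]; the two modules differ only in how they split heads vs. qkv.
def _qkv_attention_shape_sketch():
    n_heads, ch, length = 4, 16, 64
    qkv = th.randn(2, 3 * n_heads * ch, length)
    out_new = QKVAttention(n_heads)(qkv)            # split qkv first, then heads
    out_legacy = QKVAttentionLegacy(n_heads)(qkv)   # split heads first, then qkv
    return out_new.shape, out_legacy.shape          # both: (2, 64, 64)

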
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)


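# Construction sketch (illustrative only; every hyperparameter below is an assumption in
# the spirit of the LDM configs, not a value defined in this file): a 32x32, 4-channel
# latent UNet with cross-attention conditioning of dimension 768.
def _unet_usage_sketch():
    unet = UNetModel(image_size=32, in_channels=4, model_channels=64, out_channels=4,
                     num_res_blocks=2, attention_resolutions=(1, 2, 4), channel_mult=(1, 2, 4),
                     num_heads=8, use_spatial_transformer=True, transformer_depth=1,
                     context_dim=768)
    x = th.randn(2, 4, 32, 32)
    t = th.randint(0, 1000, (2,))
    context = th.randn(2, 77, 768)
    return unet(x, timesteps=t, context=context).shape  # expected: (2, 4, 32, 32)

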
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
@ -0,0 +1,267 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from ldm.util import instantiate_from_config


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


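# Usage sketch (illustrative only; the step counts are assumed): derive a 50-step DDIM
# schedule from a 1000-step linear beta schedule; eta=0 gives a fully deterministic
# sampler, so all sigmas are zero.
def _ddim_schedule_sketch():
    betas = make_beta_schedule("linear", 1000)
    alphacums = np.cumprod(1. - betas, axis=0)
    ddim_steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(alphacums, ddim_steps,
                                                                eta=0.0, verbose=False)
    return sigmas.shape, alphas.shape, alphas_prev.shape  # each: (50,)

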
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if False:  # disabled checkpointing to allow requires_grad = False for main model
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


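# Usage sketch (illustrative only; the tiny module is an assumption for the example): when
# checkpointing is active, activations inside `func` are recomputed during backward instead
# of being stored; in this copy the flag above is short-circuited, so `checkpoint` simply
# calls `func(*inputs)` and gradients flow as usual.
def _checkpoint_usage_sketch():
    lin = nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    out = checkpoint(lambda inp: lin(inp).sum(), (x,), lin.parameters(), True)
    out.backward()
    return x.grad.shape  # expected: torch.Size([2, 8])

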
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


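# Usage sketch (illustrative only): a batch of 4 integer timesteps becomes a (4, 128)
# tensor of sinusoidal features -- cosines in the first half, sines in the second half.
def _timestep_embedding_sketch():
    t = torch.tensor([0, 10, 100, 999])
    return timestep_embedding(t, dim=128).shape  # expected: torch.Size([4, 128])

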
def zero_module(module):
|
||||||
|
"""
|
||||||
|
Zero out the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().zero_()
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def scale_module(module, scale):
|
||||||
|
"""
|
||||||
|
Scale the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().mul_(scale)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def mean_flat(tensor):
|
||||||
|
"""
|
||||||
|
Take the mean over all non-batch dimensions.
|
||||||
|
"""
|
||||||
|
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||||
|
|
||||||
|
|
||||||
|
def normalization(channels):
|
||||||
|
"""
|
||||||
|
Make a standard normalization layer.
|
||||||
|
:param channels: number of input channels.
|
||||||
|
:return: an nn.Module for normalization.
|
||||||
|
"""
|
||||||
|
return GroupNorm32(32, channels)
|
||||||
|
|
||||||
|
|
||||||
|
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||||
|
class SiLU(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNorm32(nn.GroupNorm):
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
def conv_nd(dims, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a 1D, 2D, or 3D convolution module.
|
||||||
|
"""
|
||||||
|
if dims == 1:
|
||||||
|
return nn.Conv1d(*args, **kwargs)
|
||||||
|
elif dims == 2:
|
||||||
|
return nn.Conv2d(*args, **kwargs)
|
||||||
|
elif dims == 3:
|
||||||
|
return nn.Conv3d(*args, **kwargs)
|
||||||
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
|
|
||||||
|
def linear(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a linear module.
|
||||||
|
"""
|
||||||
|
return nn.Linear(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def avg_pool_nd(dims, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a 1D, 2D, or 3D average pooling module.
|
||||||
|
"""
|
||||||
|
if dims == 1:
|
||||||
|
return nn.AvgPool1d(*args, **kwargs)
|
||||||
|
elif dims == 2:
|
||||||
|
return nn.AvgPool2d(*args, **kwargs)
|
||||||
|
elif dims == 3:
|
||||||
|
return nn.AvgPool3d(*args, **kwargs)
|
||||||
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
|
|
||||||
|
class HybridConditioner(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, c_concat_config, c_crossattn_config):
|
||||||
|
super().__init__()
|
||||||
|
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||||
|
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||||
|
|
||||||
|
def forward(self, c_concat, c_crossattn):
|
||||||
|
c_concat = self.concat_conditioner(c_concat)
|
||||||
|
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||||
|
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||||
|
|
||||||
|
|
||||||
|
def noise_like(shape, device, repeat=False):
|
||||||
|
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||||
|
noise = lambda: torch.randn(shape, device=device)
|
||||||
|
return repeat_noise() if repeat else noise()
|
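As a quick sanity check on the sinusoidal embedding above, here is a minimal standalone sketch (not one of the repository files; the timestep value and dimension are arbitrary illustrative choices):

import math
import torch

# Recompute the first row of timestep_embedding(t, dim) by hand for t = 5, dim = 6:
# frequencies are max_period ** (-i / half) for i in 0..half-1, with half = dim // 2.
t, dim, max_period = torch.tensor([5.0]), 6, 10000
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(embedding.shape)  # torch.Size([1, 6]) -- [cos | sin] halves, matching the function above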
@ -0,0 +1,92 @@
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
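For orientation, a standalone sketch of the reparameterization and KL terms implemented by DiagonalGaussianDistribution above; the tensor shapes are illustrative assumptions only, not values used by the repository:

import torch

# An encoder producing 2*z_channels feature maps would feed DiagonalGaussianDistribution
# with mean and log-variance stacked along dim=1.
parameters = torch.randn(2, 8, 4, 4)          # batch of 2, 2*4 channels
mean, logvar = torch.chunk(parameters, 2, dim=1)
std = torch.exp(0.5 * logvar)
sample = mean + std * torch.randn_like(mean)  # same reparameterization as sample() above
kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])
print(sample.shape, kl.shape)                 # torch.Size([2, 4, 4, 4]) torch.Size([2])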
@ -0,0 +1,76 @@
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
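The shadow-parameter update in LitEma.forward above is a plain exponential moving average; a standalone one-line illustration with arbitrarily chosen values:

import torch

# shadow <- shadow - (1 - decay) * (shadow - param)  ==  decay * shadow + (1 - decay) * param
decay = 0.999
shadow = torch.tensor([1.0])
param = torch.tensor([2.0])
shadow.sub_((1.0 - decay) * (shadow - param))
print(shadow)  # tensor([1.0010])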
@ -0,0 +1,161 @@
import torch
from torch import nn

from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
from functools import partial

DEFAULT_PLACEHOLDER_TOKEN = ["*"]

PROGRESSIVE_SCALE = 2000

def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    tokens = batch_encoding["input_ids"]
    assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"

    return tokens[0, 1]

def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"

    token = token[0, 1]

    return token

def get_embedding_for_clip_token(embedder, token):
    return embedder(token.unsqueeze(0))[0, 0]


class EmbeddingManager(nn.Module):
    def __init__(
            self,
            embedder,
            placeholder_strings=None,
            initializer_words=None,
            per_image_tokens=False,
            num_vectors_per_token=1,
            progressive_words=False,
            **kwargs
    ):
        super().__init__()

        self.string_to_token_dict = {}

        self.string_to_param_dict = nn.ParameterDict()

        self.initial_embeddings = nn.ParameterDict()  # These should not be optimized

        self.progressive_words = progressive_words
        self.progressive_counter = 0

        self.max_vectors_per_token = num_vectors_per_token

        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            self.is_clip = True
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
            token_dim = 768
        else:  # using LDM's BERT encoder
            self.is_clip = False
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            get_embedding_for_tkn = embedder.transformer.token_emb
            token_dim = 1280

        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):

            token = get_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                init_word_token = get_token_for_string(initializer_words[idx])

                with torch.no_grad():
                    init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())

                token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
                self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
            else:
                token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params

    def forward(
            self,
            tokenized_text,
            embedded_text,
    ):
        b, n, device = *tokenized_text.shape, tokenized_text.device

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():

            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)

            if self.max_vectors_per_token == 1:  # If there's only one vector per token, we can do a simple replacement
                placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
                embedded_text[placeholder_idx] = placeholder_embedding
            else:  # otherwise, need to insert and keep track of changing indices
                if self.progressive_words:
                    self.progressive_counter += 1
                    max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
                else:
                    max_step_tokens = self.max_vectors_per_token

                num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

                placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))

                if placeholder_rows.nelement() == 0:
                    continue

                sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
                sorted_rows = placeholder_rows[sort_idx]

                for idx in range(len(sorted_rows)):
                    row = sorted_rows[idx]
                    col = sorted_cols[idx]

                    new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
                    new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]

                    embedded_text[row] = new_embed_row
                    tokenized_text[row] = new_token_row

        return embedded_text

    def save(self, ckpt_path):
        torch.save({"string_to_token": self.string_to_token_dict,
                    "string_to_param": self.string_to_param_dict}, ckpt_path)

    def load(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cpu')

        self.string_to_token_dict = ckpt["string_to_token"]
        self.string_to_param_dict = ckpt["string_to_param"]

    def get_embedding_norms_squared(self):
        all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0)  # num_placeholders x embedding_dim
        param_norm_squared = (all_params * all_params).sum(axis=-1)  # num_placeholders

        return param_norm_squared

    def embedding_parameters(self):
        return self.string_to_param_dict.parameters()

    def embedding_to_coarse_loss(self):

        loss = 0.
        num_embeddings = len(self.initial_embeddings)

        for key in self.initial_embeddings:
            optimized = self.string_to_param_dict[key]
            coarse = self.initial_embeddings[key].clone().to(optimized.device)

            loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings

        return loss
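A standalone sketch of the single-vector replacement path in EmbeddingManager.forward above; the token ids and embedding width here are hypothetical choices for illustration, not values taken from the repository:

import torch

# Every position whose token id equals the placeholder id receives the learned embedding.
placeholder_token = torch.tensor(265)                 # hypothetical token id
tokenized_text = torch.tensor([[49406, 265, 1125, 49407]])
embedded_text = torch.zeros(1, 4, 768)
placeholder_embedding = torch.ones(768)

idx = torch.where(tokenized_text == placeholder_token)
embedded_text[idx] = placeholder_embedding
print(embedded_text[0, 1, :3])  # tensor([1., 1., 1.]) -- only the placeholder slot changed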
@ -0,0 +1,396 @@
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia

from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

def _expand_mask(mask, dtype, tgt_len=None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

def _build_causal_attention_mask(bsz, seq_len, dtype):
    # lazily create causal attention mask, with full attention between the vision tokens
    # pytorch uses additive attention mask; fill with -inf
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.tensor(torch.finfo(dtype).min))
    mask.triu_(1)  # zero out the lower diagonal
    mask = mask.unsqueeze(1)  # expand mask
    return mask

class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        tokens = tokens.to(self.device)  # meh
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """Uses a pretrained BERT tokenizer from Hugging Face. Vocab size: 30522 (?)"""
    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text, embedding_manager=None):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)  # .to(self.device)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
        return z

    def encode(self, text, **kwargs):
        # output of length 77
        return self(text, **kwargs)


class SpatialRescaler(nn.Module):
    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length

        def embedding_forward(
                self,
                input_ids=None,
                position_ids=None,
                inputs_embeds=None,
                embedding_manager=None,
        ) -> torch.Tensor:

            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

            if position_ids is None:
                position_ids = self.position_ids[:, :seq_length]

            if inputs_embeds is None:
                inputs_embeds = self.token_embedding(input_ids)

            if embedding_manager is not None:
                inputs_embeds = embedding_manager(input_ids, inputs_embeds)

            position_embeddings = self.position_embedding(position_ids)
            embeddings = inputs_embeds + position_embeddings

            return embeddings

        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)

        def encoder_forward(
                self,
                inputs_embeds,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
        ):
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            encoder_states = () if output_hidden_states else None
            all_attentions = () if output_attentions else None

            hidden_states = inputs_embeds
            for idx, encoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states,)

                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            return hidden_states

        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)

        def text_encoder_forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                embedding_manager=None,
        ):
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            if input_ids is None:
                raise ValueError("You have to specify input_ids")

            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])

            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)

            bsz, seq_len = input_shape
            # CLIP's text model uses causal mask, prepare it here.
            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
            causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
                hidden_states.device
            )

            # expand attention_mask
            if attention_mask is not None:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

            last_hidden_state = self.encoder(
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            last_hidden_state = self.final_layer_norm(last_hidden_state)

            return last_hidden_state

        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)

        def transformer_forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                embedding_manager=None,
        ):
            return self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                embedding_manager=embedding_manager
            )

        self.transformer.forward = transformer_forward.__get__(self.transformer)

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text, **kwargs):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        z = self.transformer(input_ids=tokens, **kwargs)

        return z

    def encode(self, text, **kwargs):
        return self(text, **kwargs)


class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z


class FrozenClipImageEmbedder(nn.Module):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x))


if __name__ == "__main__":
    from ldm.util import count_params
    model = FrozenCLIPEmbedder()
    count_params(model, verbose=True)
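The FrozenCLIPEmbedder above swaps in its custom forward functions by rebinding plain functions as instance methods via __get__; a minimal standalone illustration of that pattern (the class and names below are made up for the example, not part of the repository):

class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    return self.name.upper()

g = Greeter("clip")
g.greet = shout.__get__(g)   # bound to this one instance, as done for the CLIP submodules above
print(g.greet())             # CLIP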
@ -0,0 +1,496 @@
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia

from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

def _expand_mask(mask, dtype, tgt_len=None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

def _build_causal_attention_mask(bsz, seq_len, dtype):
    # lazily create causal attention mask, with full attention between the vision tokens
    # pytorch uses additive attention mask; fill with -inf
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.tensor(torch.finfo(dtype).min))
    mask.triu_(1)  # zero out the lower diagonal
    mask = mask.unsqueeze(1)  # expand mask
    return mask

class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        tokens = tokens.to(self.device)  # meh
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """Uses a pretrained BERT tokenizer from Hugging Face. Vocab size: 30522 (?)"""
    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text, embedding_manager=None):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)  # .to(self.device)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
        return z

    def encode(self, text, **kwargs):
        # output of length 77
        return self(text, **kwargs)


class SpatialRescaler(nn.Module):
    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

        def embedding_forward(
                self,
                input_ids=None,
                position_ids=None,
                inputs_embeds=None,
                embedding_manager=None,
        ) -> torch.Tensor:

            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

            if position_ids is None:
                position_ids = self.position_ids[:, :seq_length]

            if inputs_embeds is None:
                inputs_embeds = self.token_embedding(input_ids)

            if embedding_manager is not None:
                inputs_embeds = embedding_manager(input_ids, inputs_embeds)

            position_embeddings = self.position_embedding(position_ids)
            embeddings = inputs_embeds + position_embeddings

            return embeddings

        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)

        def encoder_forward(
                self,
                inputs_embeds,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
        ):
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            encoder_states = () if output_hidden_states else None
            all_attentions = () if output_attentions else None

            hidden_states = inputs_embeds
            for idx, encoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states,)

                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            return hidden_states

        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)

        def text_encoder_forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                embedding_manager=None,
        ):
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            if input_ids is None:
                raise ValueError("You have to specify input_ids")

            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])

            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)

            bsz, seq_len = input_shape
            # CLIP's text model uses causal mask, prepare it here.
            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
            causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
                hidden_states.device
            )

            # expand attention_mask
            if attention_mask is not None:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

            last_hidden_state = self.encoder(
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            last_hidden_state = self.final_layer_norm(last_hidden_state)

            return last_hidden_state

        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)

        def transformer_forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                embedding_manager=None,
        ):
            return self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                embedding_manager=embedding_manager
            )

        self.transformer.forward = transformer_forward.__get__(self.transformer)

    # def update_embedding_func(self, embedding_manager):
    #     text_model = self.transformer.text_model
    #     # text_model.old_embeddings = text_model.embeddings

    #     # def new_embeddings(
    #     #     input_ids = None,
    #     #     position_ids = None,
    #     #     inputs_embeds = None,
    #     # ) -> torch.Tensor:

    #     #     seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

    #     #     if position_ids is None:
    #     #         position_ids = text_model.old_embeddings.position_ids[:, :seq_length]

    #     #     if inputs_embeds is None:
    #     #         inputs_embeds = text_model.old_embeddings.token_embedding(input_ids)

    #     #     inputs_embeds = embedding_manager(input_ids, inputs_embeds)

    #     #     position_embeddings = text_model.old_embeddings.position_embedding(position_ids)
    #     #     embeddings = inputs_embeds + position_embeddings

    #     #     return embeddings

    #     # del text_model.embeddings
    #     # text_model.embeddings = new_embeddings

    #     # class NewEmbeddings(torch.nn.Module):

    #     #     def __init__(self, orig_embedder):
    #     #         super().__init__()
    #     #         self.orig_embedder = orig_embedder

    #     #     def forward(
    #     #         self,
    #     #         input_ids = None,
    #     #         position_ids = None,
    #     #         inputs_embeds = None,
    #     #     ) -> torch.Tensor:

    #     #         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

    #     #         if position_ids is None:
    #     #             position_ids = self.orig_embedder.position_ids[:, :seq_length]

    #     #         if inputs_embeds is None:
    #     #             inputs_embeds = self.orig_embedder.token_embedding(input_ids)

    #     #         inputs_embeds = embedding_manager(input_ids, inputs_embeds)

    #     #         position_embeddings = self.orig_embedder.position_embedding(position_ids)
    #     #         embeddings = inputs_embeds + position_embeddings

    #     #         return embeddings

    #     # # self.new_embeddings =
    #     # # text_model.embeddings = new_embeddings.__call__.__get__(text_model)
    #     # text_model.embeddings = NewEmbeddings(text_model.embeddings)

    #     class NewEmbeddings(torch.nn.Module):

    #         def __init__(self, orig_embedder, embedding_manager):
    #             super().__init__()
    #             self.embedding_manager = embedding_manager
    #             self.orig_embedder = orig_embedder

    #         def forward(
    #             self,
    #             input_ids = None,
    #             position_ids = None,
    #             inputs_embeds = None,
    #         ) -> torch.Tensor:

    #             seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

    #             if position_ids is None:
    #                 position_ids = self.orig_embedder.position_ids[:, :seq_length]

    #             if inputs_embeds is None:
    #                 inputs_embeds = self.orig_embedder.token_embedding(input_ids)

    #             # init_embeds = inputs_embeds.clone()
    #             inputs_embeds = self.embedding_manager(input_ids, inputs_embeds)

    #             # print(inputs_embeds - init_embeds)
    #             # print((inputs_embeds - init_embeds).max())
    #             # exit(0)

    #             position_embeddings = self.orig_embedder.position_embedding(position_ids)
    #             embeddings = inputs_embeds + position_embeddings

    #             return embeddings

    #     # self.new_embeddings =
    #     # text_model.embeddings = new_embeddings.__call__.__get__(text_model)
    #     text_model.embeddings = NewEmbeddings(text_model.embeddings, embedding_manager)

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text, **kwargs):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        z = self.transformer(input_ids=tokens, **kwargs)

        return z

    def encode(self, text, **kwargs):
        return self(text, **kwargs)


class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z


class FrozenClipImageEmbedder(nn.Module):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x))


if __name__ == "__main__":
    from ldm.util import count_params
    model = FrozenCLIPEmbedder()
    count_params(model, verbose=True)
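A standalone illustration of what _build_causal_attention_mask above produces, assuming a sequence length of 4; zeros in the additive mask mark the positions each token may attend to:

import torch

mask = torch.empty(1, 4, 4)
mask.fill_(torch.finfo(torch.float32).min)
mask.triu_(1)
print((mask == 0).int()[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)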
@ -0,0 +1,2 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light