[Vae and AutoencoderKL] Final clean of LDM checkpoints (#137)

* [Vae and AutoencoderKL clean]

* save intermediate finished work

* more progress

* more progress

* finish modeling code

* save intermediate

* finish

* Correct tests
Author: Patrick von Platen, 2022-07-28 10:14:34 +02:00, committed by GitHub
parent e05f03ae41
commit 3100bc9670
6 changed files with 479 additions and 301 deletions

File 1/6: DDPM/LDM checkpoint conversion script

@@ -1,4 +1,4 @@
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json
 import torch
@@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

-            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+            num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
@@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue

-        new_path = new_path.replace('down.', 'downsample_blocks.')
+        new_path = new_path.replace('down.', 'down_blocks.')
         new_path = new_path.replace('up.', 'up_blocks.')

         if additional_replacements is not None:
@@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
     new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
     new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

-    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
+    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

     num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
     up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

-    for i in range(num_downsample_blocks):
-        block_id = (i - 1) // (config['num_res_blocks'] + 1)
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)

-        if any('downsample' in layer for layer in downsample_blocks[i]):
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

-        if any('block' in layer for layer in downsample_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_blocks > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in downsample_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_attn > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_blocks > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_attn > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
     return new_checkpoint


+def convert_vq_autoenc_checkpoint(checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+    new_checkpoint = {}
+
+    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
+    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
+
+    new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
+    new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
+    new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
+    new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
+
+    new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
+    new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
+
+    new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
+    new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
+    new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
+    new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
+
+    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+
+    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)
+
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
+
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+    # Mid new 2
+    paths = renew_resnet_paths(mid_block_1_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
+    ])
+
+    paths = renew_resnet_paths(mid_block_2_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
+    ])
+
+    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
+    ])
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+
+        if any('upsample' in layer for layer in up_blocks[i]):
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
+
+        if any('block' in layer for layer in up_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+        if any('attn' in layer for layer in up_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+
+    new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+    if "quantize.embedding.weight" in checkpoint:
+        new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+    new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+    return new_checkpoint
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -220,15 +331,29 @@ if __name__ == "__main__":
     with open(args.config_file) as f:
         config = json.loads(f.read())

-    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+    # unet case
+    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
+    if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+        converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+    else:
+        converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)

     if "ddpm" in config:
         del config["ddpm"]

-    model = UNet2DModel(**config)
-    model.load_state_dict(converted_checkpoint)
+    if config["_class_name"] == "VQModel":
+        model = VQModel(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    elif config["_class_name"] == "AutoencoderKL":
+        model = AutoencoderKL(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    else:
+        model = UNet2DModel(**config)
+        model.load_state_dict(converted_checkpoint)

-    scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
-    pipe = DDPMPipeline(unet=model, scheduler=scheduler)
-    pipe.save_pretrained(args.dump_path)
+        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+        pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+        pipe.save_pretrained(args.dump_path)
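The dispatch added above keys off the top-level prefixes of the state dict: only an autoencoder checkpoint contains both "encoder." and "decoder." weight groups. A minimal sketch of the same check in isolation (the checkpoint path is hypothetical, and some real LDM checkpoints additionally nest their weights under a "state_dict" key):

    import torch

    # hypothetical path; any flat LDM-style state dict works the same way
    checkpoint = torch.load("ldm_model.ckpt", map_location="cpu")

    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
    if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
        print("autoencoder weights -> convert_vq_autoenc_checkpoint")
    else:
        print("UNet weights -> convert_ddpm_checkpoint")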

File 2/6: resnet blocks module (ResnetBlock)

@@ -288,7 +288,10 @@ class ResnetBlock(nn.Module):
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

-        self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        if temb_channels is not None:
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        else:
+            self.time_emb_proj = None

         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
@@ -364,8 +367,9 @@ class ResnetBlock(nn.Module):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data

-        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
-        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
+        if self.time_emb_proj is not None:
+            self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+            self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
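With this change, ResnetBlock can be built without a time-embedding projection, which is what the new VAE encoder/decoder blocks rely on since a VAE has no timestep input. A minimal sketch, assuming the keyword names visible in this diff and the module path of this revision:

    import torch
    from diffusers.models.resnet import ResnetBlock  # module path at this revision

    block = ResnetBlock(
        in_channels=64,
        out_channels=64,
        temb_channels=None,  # the new branch: no time_emb_proj is created
        eps=1e-6,
        groups=32,
    )

    x = torch.randn(1, 64, 32, 32)
    y = block(x, temb=None)  # forward is called with temb=None, as the VAE blocks do
    print(y.shape)           # torch.Size([1, 64, 32, 32])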

File 3/6: UNet building blocks module (get_down_block / get_up_block)

@@ -92,6 +92,16 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "DownEncoderBlock2D":
+        return DownEncoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+        )


 def get_up_block(
@@ -165,6 +175,15 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif up_block_type == "UpDecoderBlock2D":
+        return UpDecoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+        )
     raise ValueError(f"{up_block_type} does not exist.")
@@ -553,6 +572,66 @@ class DownBlock2D(nn.Module):
         return hidden_states, output_states


+class DownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor=1.0,
+        add_downsample=True,
+        downsample_padding=1,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=None)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -946,6 +1025,60 @@ class UpBlock2D(nn.Module):
         return hidden_states


+class UpDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor=1.0,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=None)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipUpBlock2D(nn.Module):
     def __init__(
         self,
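The two new blocks are plain stacks of time-embedding-free ResnetBlocks followed by an optional resampling layer. A shape-level sketch of how they compose, using only constructor arguments visible above (everything not shown is assumed to keep its default):

    import torch
    from diffusers.models.unet_blocks import DownEncoderBlock2D, UpDecoderBlock2D

    down = DownEncoderBlock2D(in_channels=32, out_channels=64, num_layers=2, add_downsample=True)
    up = UpDecoderBlock2D(in_channels=64, out_channels=32, num_layers=3, add_upsample=True)

    x = torch.randn(1, 32, 16, 16)
    h = down(x)  # two resnets, then the strided "op" conv: (1, 64, 8, 8)
    y = up(h)    # three resnets, then the conv upsampler:  (1, 32, 16, 16)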

File 4/6: VAE module (Encoder, Decoder, VQModel, AutoencoderKL)

@@ -4,221 +4,164 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
-from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
-
-
-def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
-
-
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


 class Encoder(nn.Module):
     def __init__(
         self,
-        *,
-        ch,
-        ch_mult=(1, 2, 4, 8),
-        num_res_blocks,
-        attn_resolutions,
-        dropout=0.0,
-        resamp_with_conv=True,
-        in_channels,
-        resolution,
-        z_channels,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=2,
+        act_fn="silu",
         double_z=True,
-        **ignore_kwargs,
     ):
         super().__init__()
-        self.ch = ch
-        self.temb_ch = 0
-        self.num_resolutions = len(ch_mult)
-        self.num_res_blocks = num_res_blocks
-        self.resolution = resolution
-        self.in_channels = in_channels
+        self.layers_per_block = layers_per_block

-        # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

-        curr_res = resolution
-        in_ch_mult = (1,) + tuple(ch_mult)
-        self.down = nn.ModuleList()
-        for i_level in range(self.num_resolutions):
-            block = nn.ModuleList()
-            attn = nn.ModuleList()
-            block_in = ch * in_ch_mult[i_level]
-            block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks):
-                block.append(
-                    ResnetBlock2D(
-                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
-                    )
-                )
-                block_in = block_out
-                if curr_res in attn_resolutions:
-                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
-            down = nn.Module()
-            down.block = block
-            down.attn = attn
-            if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
-                curr_res = curr_res // 2
-            self.down.append(down)
+        self.mid_block = None
+        self.down_blocks = nn.ModuleList([])

-        # middle
-        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
-        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
-        self.mid.block_2 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=self.layers_per_block,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                add_downsample=not is_final_block,
+                resnet_eps=1e-6,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=None,
+                temb_channels=None,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attn_num_head_channels=None,
+            resnet_groups=32,
+            temb_channels=None,
+        )

-        # end
-        self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(
-            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
-        )
+        # out
+        num_groups_out = 32
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+        self.conv_act = nn.SiLU()
+        conv_out_channels = 2 * out_channels if double_z else out_channels
+        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

     def forward(self, x):
-        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+        sample = x
+        sample = self.conv_in(sample)

-        # timestep embedding
-        temb = None
-
-        # downsampling
-        hs = [self.conv_in(x)]
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1], temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
-            if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+        # down
+        for down_block in self.down_blocks:
+            sample = down_block(sample)

         # middle
-        h = hs[-1]
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        sample = self.mid_block(sample)

-        # end
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
-        return h
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample


 class Decoder(nn.Module):
     def __init__(
         self,
-        *,
-        ch,
-        out_ch,
-        ch_mult=(1, 2, 4, 8),
-        num_res_blocks,
-        attn_resolutions,
-        dropout=0.0,
-        resamp_with_conv=True,
-        in_channels,
-        resolution,
-        z_channels,
-        give_pre_end=False,
-        **ignorekwargs,
+        in_channels=3,
+        out_channels=3,
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=2,
+        act_fn="silu",
     ):
         super().__init__()
-        self.ch = ch
-        self.temb_ch = 0
-        self.num_resolutions = len(ch_mult)
-        self.num_res_blocks = num_res_blocks
-        self.resolution = resolution
-        self.in_channels = in_channels
-        self.give_pre_end = give_pre_end
+        self.layers_per_block = layers_per_block

-        # compute in_ch_mult, block_in and curr_res at lowest res
-        block_in = ch * ch_mult[self.num_resolutions - 1]
-        curr_res = resolution // 2 ** (self.num_resolutions - 1)
-        self.z_shape = (1, z_channels, curr_res, curr_res)
-        # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

-        # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])

-        # middle
-        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
-        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
-        self.mid.block_2 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attn_num_head_channels=None,
+            resnet_groups=32,
+            temb_channels=None,
+        )

-        # upsampling
-        self.up = nn.ModuleList()
-        for i_level in reversed(range(self.num_resolutions)):
-            block = nn.ModuleList()
-            attn = nn.ModuleList()
-            block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
-                block.append(
-                    ResnetBlock2D(
-                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
-                    )
-                )
-                block_in = block_out
-                if curr_res in attn_resolutions:
-                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
-            up = nn.Module()
-            up.block = block
-            up.attn = attn
-            if i_level != 0:
-                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
-                curr_res = curr_res * 2
-            self.up.insert(0, up)  # prepend to get consistent order
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]

-        # end
-        self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+            is_final_block = i == len(block_out_channels) - 1
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=self.layers_per_block + 1,
+                in_channels=prev_output_channel,
+                out_channels=output_channel,
+                prev_output_channel=None,
+                add_upsample=not is_final_block,
+                resnet_eps=1e-6,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=None,
+                temb_channels=None,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        num_groups_out = 32
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

     def forward(self, z):
-        # assert z.shape[1:] == self.z_shape[1:]
-        self.last_z_shape = z.shape
-
-        # timestep embedding
-        temb = None
-
-        # z to block_in
-        h = self.conv_in(z)
+        sample = z
+        sample = self.conv_in(sample)

         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        sample = self.mid_block(sample)

-        # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks + 1):
-                h = self.up[i_level].block[i_block](h, temb)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h)
-            if i_level != 0:
-                h = self.up[i_level].upsample(h)
+        # up
+        for up_block in self.up_blocks:
+            sample = up_block(sample)

-        # end
-        if self.give_pre_end:
-            return h
-
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
-        return h
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample


 class VectorQuantizer(nn.Module):
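The refactored Encoder/Decoder are now driven entirely by block-type and channel lists. A minimal round trip using only arguments visible in this diff (with double_z=True the encoder emits twice out_channels channels, so a KL model can split them into mean and log-variance):

    import torch
    from diffusers.models.vae import Encoder, Decoder  # module path at this revision

    enc = Encoder(
        in_channels=3,
        out_channels=4,
        down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
        block_out_channels=(32, 64),
        layers_per_block=1,
        double_z=True,
    )
    dec = Decoder(
        in_channels=4,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
        block_out_channels=(32, 64),
        layers_per_block=1,
    )

    img = torch.randn(1, 3, 32, 32)
    moments = enc(img)                      # (1, 8, 16, 16): 2 * out_channels because double_z=True
    mean, logvar = moments.chunk(2, dim=1)  # split into posterior statistics
    rec = dec(mean)                         # (1, 3, 32, 32)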
@@ -383,57 +326,44 @@ class VQModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        ch,
-        out_ch,
-        num_res_blocks,
-        attn_resolutions,
-        in_channels,
-        resolution,
-        z_channels,
-        n_embed,
-        embed_dim,
-        remap=None,
-        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
-        double_z=True,
-        resamp_with_conv=True,
-        give_pre_end=False,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=1,
+        act_fn="silu",
+        latent_channels=3,
+        sample_size=32,
+        num_vq_embeddings=256,
     ):
         super().__init__()

         # pass init params to Encoder
         self.encoder = Encoder(
-            ch=ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
             in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            double_z=double_z,
-            give_pre_end=give_pre_end,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            double_z=False,
         )

-        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
-        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+        self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        self.quantize = VectorQuantizer(
+            num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+        )
+        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

         # pass init params to Decoder
         self.decoder = Decoder(
-            ch=ch,
-            out_ch=out_ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
-            in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            give_pre_end=give_pre_end,
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
         )

     def encode(self, x):
@@ -462,57 +392,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        ch,
-        out_ch,
-        num_res_blocks,
-        attn_resolutions,
-        in_channels,
-        resolution,
-        z_channels,
-        embed_dim,
-        remap=None,
-        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
-        double_z=True,
-        resamp_with_conv=True,
-        give_pre_end=False,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=1,
+        act_fn="silu",
+        latent_channels=4,
+        sample_size=32,
     ):
         super().__init__()

         # pass init params to Encoder
         self.encoder = Encoder(
-            ch=ch,
-            out_ch=out_ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
             in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            double_z=double_z,
-            give_pre_end=give_pre_end,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            double_z=True,
         )

         # pass init params to Decoder
         self.decoder = Decoder(
-            ch=ch,
-            out_ch=out_ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
-            in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            give_pre_end=give_pre_end,
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
         )

-        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

     def encode(self, x):
         h = self.encoder(x)
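Putting the new config surface together, an end-to-end sketch (argument names are the registered ones above; that encode returns a diagonal-Gaussian posterior matches the unchanged encode/decode bodies of this revision):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL(
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=1,
        latent_channels=4,
        sample_size=32,
    )

    image = torch.randn(1, 3, 32, 32)
    posterior = vae.encode(image)   # quant_conv output wrapped in a diagonal Gaussian
    z = posterior.sample()          # (1, 4, 32, 32): a single block, so no downsampling
    reconstruction = vae.decode(z)  # (1, 3, 32, 32)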

File 5/6: testing utilities (set_seed)

@@ -28,8 +28,8 @@ def enable_full_determinism(seed: int):

 def set_seed(seed: int):
     """
-    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
     Args:
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
         seed (`int`): The seed to set.
     """
     random.seed(seed)

File 6/6: model tests (VQModelTests, AutoencoderKLTests)

@@ -555,18 +555,12 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "ch": 64,
-            "out_ch": 3,
-            "num_res_blocks": 1,
+            "block_out_channels": [64],
             "in_channels": 3,
-            "attn_resolutions": [],
-            "resolution": 32,
-            "z_channels": 3,
-            "n_embed": 256,
-            "embed_dim": 3,
-            "sane_index_shape": False,
-            "ch_mult": (1,),
-            "double_z": False,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D"],
+            "latent_channels": 3,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -595,7 +589,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         with torch.no_grad():
             output = model(image)
@@ -639,6 +633,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
-            "resolution": 32,
-            "z_channels": 4,
-        }
+        init_dict = {
+            "block_out_channels": [64],
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D"],
+            "latent_channels": 4,
+        }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -666,13 +668,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         with torch.no_grad():
             output = model(image, sample_posterior=True)

         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
+        expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
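The expected slice changes presumably because the refactor alters the module tree, and with it the order in which parameters are created and initialized under a fixed seed. A sketch of how such a reference slice can be regenerated (mirrors the test body above, on CPU):

    import torch
    from diffusers import AutoencoderKL

    torch.manual_seed(0)
    model = AutoencoderKL(
        block_out_channels=[64],
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D"],
        up_block_types=["UpDecoderBlock2D"],
        latent_channels=4,
    )
    model.eval()

    image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    with torch.no_grad():
        output = model(image, sample_posterior=True)

    print(output[0, -1, -3:, -3:].flatten())  # values to paste into expected_output_slice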