From 3100bc967084964480628ae61210b7eaa7436f1d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 28 Jul 2022 10:14:34 +0200 Subject: [PATCH] [Vae and AutoencoderKL] Final clean of LDM checkpoints (#137) * [Vae and AutoencoderKL clean] * save intermediate finished work * more progress * more progress * finish modeling code * save intermediate * finish * Correct tests --- ...t_ddpm_original_checkpoint_to_diffusers.py | 181 ++++++-- src/diffusers/models/resnet.py | 10 +- src/diffusers/models/unet_blocks.py | 133 ++++++ src/diffusers/models/vae.py | 424 +++++++----------- src/diffusers/training_utils.py | 2 +- tests/test_modeling_utils.py | 30 +- 6 files changed, 479 insertions(+), 301 deletions(-) diff --git a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py index 216018c6..88cd92d8 100644 --- a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -1,4 +1,4 @@ -from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline +from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL import argparse import json import torch @@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) @@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue - new_path = new_path.replace('down.', 'downsample_blocks.') + new_path = new_path.replace('down.', 'down_blocks.') new_path = new_path.replace('up.', 'up_blocks.') if additional_replacements is not None: @@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config): new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight'] new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias'] - num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer}) - downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)} + num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer}) + down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)} num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer}) up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)} - for i in range(num_downsample_blocks): - block_id = (i - 1) // (config['num_res_blocks'] + 1) + for i in range(num_down_blocks): + block_id = (i - 1) // (config['layers_per_block'] + 1) - if any('downsample' in layer for layer in downsample_blocks[i]): - new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] - new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias'] - new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] - new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias'] + if any('downsample' in layer for layer in down_blocks[i]): + new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight'] + new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias'] +# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] +# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias'] - if any('block' in layer for layer in downsample_blocks[i]): - num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer}) - blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + if any('block' in layer for layer in down_blocks[i]): + num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer}) + blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} if num_blocks > 0: - for j in range(config['num_res_blocks']): + for j in range(config['layers_per_block']): paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint) - if any('attn' in layer for layer in downsample_blocks[i]): - num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer}) - attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + if any('attn' in layer for layer in down_blocks[i]): + num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer}) + attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} if num_attn > 0: - for j in range(config['num_res_blocks']): + for j in range(config['layers_per_block']): paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) @@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config): blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} if num_blocks > 0: - for j in range(config['num_res_blocks'] + 1): + for j in range(config['layers_per_block'] + 1): replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) @@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config): attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} if num_attn > 0: - for j in range(config['num_res_blocks'] + 1): + for j in range(config['layers_per_block'] + 1): replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) @@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config): return new_checkpoint +def convert_vq_autoenc_checkpoint(checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight'] + new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias'] + + new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight'] + new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias'] + new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight'] + new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias'] + + new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight'] + new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias'] + + new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight'] + new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias'] + new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight'] + new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias'] + + num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer}) + down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)} + + num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer}) + up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + block_id = (i - 1) // (config['layers_per_block'] + 1) + + if any('downsample' in layer for layer in down_blocks[i]): + new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight'] + new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias'] + + if any('block' in layer for layer in down_blocks[i]): + num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer}) + blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + + if num_blocks > 0: + for j in range(config['layers_per_block']): + paths = renew_resnet_paths(blocks[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint) + + if any('attn' in layer for layer in down_blocks[i]): + num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer}) + attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + + if num_attn > 0: + for j in range(config['layers_per_block']): + paths = renew_attention_paths(attns[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) + + mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key] + mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key] + mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key] + + # Mid new 2 + paths = renew_resnet_paths(mid_block_1_layers) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ + {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'} + ]) + + paths = renew_resnet_paths(mid_block_2_layers) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ + {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'} + ]) + + paths = renew_attention_paths(mid_attn_1_layers, in_mid=True) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ + {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'} + ]) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + if any('upsample' in layer for layer in up_blocks[i]): + new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight'] + new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias'] + + if any('block' in layer for layer in up_blocks[i]): + num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer}) + blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + + if num_blocks > 0: + for j in range(config['layers_per_block'] + 1): + replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} + paths = renew_resnet_paths(blocks[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) + + if any('attn' in layer for layer in up_blocks[i]): + num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer}) + attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + + if num_attn > 0: + for j in range(config['layers_per_block'] + 1): + replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} + paths = renew_attention_paths(attns[j]) + assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) + + new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()} + new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"] + if "quantize.embedding.weight" in checkpoint: + new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"] + new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"] + + return new_checkpoint + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -220,15 +331,29 @@ if __name__ == "__main__": with open(args.config_file) as f: config = json.loads(f.read()) - converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config) + # unet case + key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys()) + if "encoder" in key_prefix_set and "decoder" in key_prefix_set: + converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config) + else: + converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config) if "ddpm" in config: del config["ddpm"] - model = UNet2DModel(**config) - model.load_state_dict(converted_checkpoint) + if config["_class_name"] == "VQModel": + model = VQModel(**config) + model.load_state_dict(converted_checkpoint) + model.save_pretrained(args.dump_path) + elif config["_class_name"] == "AutoencoderKL": + model = AutoencoderKL(**config) + model.load_state_dict(converted_checkpoint) + model.save_pretrained(args.dump_path) + else: + model = UNet2DModel(**config) + model.load_state_dict(converted_checkpoint) - scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) + scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) - pipe = DDPMPipeline(unet=model, scheduler=scheduler) - pipe.save_pretrained(args.dump_path) + pipe = DDPMPipeline(unet=model, scheduler=scheduler) + pipe.save_pretrained(args.dump_path) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index a54199c1..98244261 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -288,7 +288,10 @@ class ResnetBlock(nn.Module): self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) @@ -364,8 +367,9 @@ class ResnetBlock(nn.Module): self.conv1.weight.data = resnet.conv1.weight.data self.conv1.bias.data = resnet.conv1.bias.data - self.time_emb_proj.weight.data = resnet.temb_proj.weight.data - self.time_emb_proj.bias.data = resnet.temb_proj.bias.data + if self.time_emb_proj is not None: + self.time_emb_proj.weight.data = resnet.temb_proj.weight.data + self.time_emb_proj.bias.data = resnet.temb_proj.bias.data self.norm2.weight.data = resnet.norm2.weight.data self.norm2.bias.data = resnet.norm2.bias.data diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 67082d24..034e662e 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -92,6 +92,16 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + ) def get_up_block( @@ -165,6 +175,15 @@ def get_up_block( resnet_act_fn=resnet_act_fn, attn_num_head_channels=attn_num_head_channels, ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) raise ValueError(f"{up_block_type} does not exist.") @@ -553,6 +572,66 @@ class DownBlock2D(nn.Module): return hidden_states, output_states +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + class AttnSkipDownBlock2D(nn.Module): def __init__( self, @@ -946,6 +1025,60 @@ class UpBlock2D(nn.Module): return hidden_states +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + class AttnSkipUpBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index d16ab792..20a49659 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -4,221 +4,164 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin -from .attention import AttentionBlock -from .resnet import Downsample2D, ResnetBlock2D, Upsample2D - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block class Encoder(nn.Module): def __init__( self, - *, - ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", double_z=True, - **ignore_kwargs, ): super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels + self.layers_per_block = layers_per_block - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock2D( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttentionBlock(block_in, overwrite_qkv=True)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0) - curr_res = curr_res // 2 - self.down.append(down) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) - self.mid.block_2 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, ) - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + sample = x + sample = self.conv_in(sample) - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) + # down + for down_block in self.down_blocks: + sample = down_block(sample) # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + sample = self.mid_block(sample) - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample class Decoder(nn.Module): def __init__( self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + act_fn="silu", ): super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end + self.layers_per_block = layers_per_block - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) - self.mid.block_2 = ResnetBlock2D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=32, + temb_channels=None, ) - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock2D( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttentionBlock(block_in, overwrite_qkv=True)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = 32 + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) + sample = z + sample = self.conv_in(sample) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + sample = self.mid_block(sample) - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) + # up + for up_block in self.up_blocks: + sample = up_block(sample) - # end - if self.give_pre_end: - return h + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h + return sample class VectorQuantizer(nn.Module): @@ -383,57 +326,44 @@ class VQModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=1, + act_fn="silu", + latent_channels=3, + sample_size=32, + num_vq_embeddings=256, ): super().__init__() # pass init params to Encoder self.encoder = Encoder( - ch=ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=False, ) - self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.quantize = VectorQuantizer( + num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False + ) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, ) def encode(self, x): @@ -462,57 +392,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=1, + act_fn="silu", + latent_channels=4, + sample_size=32, ): super().__init__() # pass init params to Encoder self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + double_z=True, ) # pass init params to Decoder self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, ) - self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) def encode(self, x): h = self.encoder(x) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 022be41a..fa169416 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -28,8 +28,8 @@ def enable_full_determinism(seed: int): def set_seed(seed: int): """ - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. seed (`int`): The seed to set. """ random.seed(seed) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ed98a9e5..cd5767c4 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -555,18 +555,12 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "ch": 64, - "out_ch": 3, - "num_res_blocks": 1, + "block_out_channels": [64], "in_channels": 3, - "attn_resolutions": [], - "resolution": 32, - "z_channels": 3, - "n_embed": 256, - "embed_dim": 3, - "sane_index_shape": False, - "ch_mult": (1,), - "double_z": False, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D"], + "latent_channels": 3, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -595,7 +589,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) with torch.no_grad(): output = model(image) @@ -639,6 +633,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): "resolution": 32, "z_channels": 4, } + init_dict = { + "block_out_channels": [64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D"], + "latent_channels": 4, + } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -666,13 +668,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) with torch.no_grad(): output = model(image, sample_posterior=True) output_slice = output[0, -1, -3:, -3:].flatten() # fmt: off - expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750]) + expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))