"""
|
|
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.
|
|
|
|
It currently only supports porting the ITHQ dataset.
|
|
|
|
ITHQ dataset:
|
|
```sh
|
|
# From the root directory of diffusers.
|
|
|
|
# Download the VQVAE checkpoint
|
|
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth
|
|
|
|
# Download the VQVAE config
|
|
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
|
|
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
|
|
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
|
|
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml
|
|
|
|
# Download the main model checkpoint
|
|
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth
|
|
|
|
# Download the main model config
|
|
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml
|
|
|
|
# run the convert script
|
|
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
|
|
--checkpoint_path ./ithq_learnable.pth \
|
|
--original_config_file ./ithq.yaml \
|
|
--vqvae_checkpoint_path ./ithq_vqvae.pth \
|
|
--vqvae_original_config_file ./ithq_vqvae.yaml \
|
|
--dump_path <path to save pre-trained `VQDiffusionPipeline`>
|
|
```
|
|
"""

import argparse
import tempfile

import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings


try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError(
        "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
        " OmegaConf`."
    )

# vqvae model

PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]


def vqvae_model_from_original_config(original_config):
    assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."

    original_config = original_config.params

    original_encoder_config = original_config.encoder_config.params
    original_decoder_config = original_config.decoder_config.params

    in_channels = original_encoder_config.in_channels
    out_channels = original_decoder_config.out_ch

    down_block_types = get_down_block_types(original_encoder_config)
    up_block_types = get_up_block_types(original_decoder_config)

    assert original_encoder_config.ch == original_decoder_config.ch
    assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
    block_out_channels = tuple(
        [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
    )

    assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
    layers_per_block = original_encoder_config.num_res_blocks

    assert original_encoder_config.z_channels == original_decoder_config.z_channels
    latent_channels = original_encoder_config.z_channels

    num_vq_embeddings = original_config.n_embed

    # Hard coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
    norm_num_groups = 32

    e_dim = original_config.embed_dim

    model = VQModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        latent_channels=latent_channels,
        num_vq_embeddings=num_vq_embeddings,
        norm_num_groups=norm_num_groups,
        vq_embed_dim=e_dim,
    )

    return model


def get_down_block_types(original_encoder_config):
    attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
    num_resolutions = len(original_encoder_config.ch_mult)
    resolution = coerce_resolution(original_encoder_config.resolution)

    curr_res = resolution
    down_block_types = []

    for _ in range(num_resolutions):
        if curr_res in attn_resolutions:
            down_block_type = "AttnDownEncoderBlock2D"
        else:
            down_block_type = "DownEncoderBlock2D"

        down_block_types.append(down_block_type)

        curr_res = [r // 2 for r in curr_res]

    return down_block_types


def get_up_block_types(original_decoder_config):
    attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
    num_resolutions = len(original_decoder_config.ch_mult)
    resolution = coerce_resolution(original_decoder_config.resolution)

    curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
    up_block_types = []

    for _ in reversed(range(num_resolutions)):
        if curr_res in attn_resolutions:
            up_block_type = "AttnUpDecoderBlock2D"
        else:
            up_block_type = "UpDecoderBlock2D"

        up_block_types.append(up_block_type)

        curr_res = [r * 2 for r in curr_res]

    return up_block_types


def coerce_attn_resolutions(attn_resolutions):
    attn_resolutions = OmegaConf.to_object(attn_resolutions)
    attn_resolutions_ = []
    for ar in attn_resolutions:
        if isinstance(ar, (list, tuple)):
            attn_resolutions_.append(list(ar))
        else:
            attn_resolutions_.append([ar, ar])
    return attn_resolutions_


def coerce_resolution(resolution):
    resolution = OmegaConf.to_object(resolution)
    if isinstance(resolution, int):
        resolution = [resolution, resolution]  # H, W
    elif isinstance(resolution, (tuple, list)):
        resolution = list(resolution)
    else:
        raise ValueError("Unknown type of resolution:", resolution)
    return resolution
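

# Note: the two coercion helpers above normalize resolution-like config entries into [height, width]
# lists so the block-type selection can compare them uniformly; e.g. a scalar attention resolution of
# 32 is treated as [32, 32]. (The concrete values come from the original config, not assumed here.)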


# done vqvae model

# vqvae checkpoint


def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))

    # quant_conv

    diffusers_checkpoint.update(
        {
            "quant_conv.weight": checkpoint["quant_conv.weight"],
            "quant_conv.bias": checkpoint["quant_conv.bias"],
        }
    )

    # quantize
    diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})

    # post_quant_conv
    diffusers_checkpoint.update(
        {
            "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
            "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
        }
    )

    # decoder
    diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))

    return diffusers_checkpoint


def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv_in
    diffusers_checkpoint.update(
        {
            "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
            "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
        diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
        down_block_prefix = f"encoder.down.{down_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(down_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # downsample

        # Do not include the downsample when on the last down block;
        # there is no downsample on the last down block.
        if down_block_idx != len(model.encoder.down_blocks) - 1:
            # There's a single downsample in the original checkpoint but a list of downsamplers
            # in the diffusers model.
            diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
            downsample_prefix = f"{down_block_prefix}.downsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
                    f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(down_block, "attentions"):
            for attention_idx, _ in enumerate(down_block.attentions):
                diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block

    # mid block attentions

    # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
    diffusers_attention_prefix = "encoder.mid_block.attentions.0"
    attention_prefix = "encoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
        diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"

        # The original checkpoint names the two mid-block resnets `block_1` and `block_2`,
        # so the 0-based diffusers index is shifted up by one.
        orig_resnet_idx = diffusers_resnet_idx + 1
        # There are two hardcoded resnets in the middle of the VQ-diffusion encoder
        resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
            "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
            # conv_out
            "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
            "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint


def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv in
    diffusers_checkpoint.update(
        {
            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
        }
    )

    # up_blocks

    for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
        # up_blocks are stored in reverse order in the VQ-diffusion checkpoint
        orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx

        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
        up_block_prefix = f"decoder.up.{orig_up_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(up_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # upsample

        # there is no upsample on the last up block
        if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
            # There's a single upsample in the VQ-diffusion checkpoint but a list of upsamplers
            # in the diffusers model.
            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
            upsample_prefix = f"{up_block_prefix}.upsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(up_block, "attentions"):
            for attention_idx, _ in enumerate(up_block.attentions):
                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block

    # mid block attentions

    # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
    attention_prefix = "decoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"

        # The original checkpoint names the two mid-block resnets `block_1` and `block_2`,
        # so the 0-based diffusers index is shifted up by one.
        orig_resnet_idx = diffusers_resnet_idx + 1
        # There are two hardcoded resnets in the middle of the VQ-diffusion decoder
        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
            "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
            # conv_out
            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint


def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
            }
        )

    return rv


def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
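    # The original attention block implements q/k/v and proj_out as 1x1 convolutions, while the
    # diffusers attention block uses linear layers; the `[:, :, 0, 0]` slicing below drops the
    # trailing 1x1 spatial dimensions of the conv kernels to obtain the equivalent linear weights.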
    return {
        # group_norm
        f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
        f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
        # query
        f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
        # key
        f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
        # value
        f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
        # proj_attn
        f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
            :, :, 0, 0
        ],
        f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
    }


# done vqvae checkpoint

# transformer model

PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]


def transformer_model_from_original_config(
    original_diffusion_config, original_transformer_config, original_content_embedding_config
):
    assert (
        original_diffusion_config.target in PORTED_DIFFUSIONS
    ), f"{original_diffusion_config.target} has not yet been ported to diffusers."
    assert (
        original_transformer_config.target in PORTED_TRANSFORMERS
    ), f"{original_transformer_config.target} has not yet been ported to diffusers."
    assert (
        original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
    ), f"{original_content_embedding_config.target} has not yet been ported to diffusers."

    original_diffusion_config = original_diffusion_config.params
    original_transformer_config = original_transformer_config.params
    original_content_embedding_config = original_content_embedding_config.params

    inner_dim = original_transformer_config["n_embd"]

    n_heads = original_transformer_config["n_head"]

    # VQ-Diffusion gives the dimension of the multi-headed attention layers as the
    # number of attention heads times the dimension of a single head. We want to
    # specify our attention blocks with those two values given separately.
    assert inner_dim % n_heads == 0
    d_head = inner_dim // n_heads
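    # For example, an inner dimension of 1024 split across 16 heads gives a head dimension of 64.
    # (Illustrative numbers only; the actual values come from the original transformer config.)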

    depth = original_transformer_config["n_layer"]
    context_dim = original_transformer_config["condition_dim"]

    num_embed = original_content_embedding_config["num_embed"]
    # The number of embeddings in the transformer includes the mask embedding.
    # The content embedding (the vqvae) does not include the mask embedding.
    num_embed = num_embed + 1
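    # In other words, a content embedding with N image-token classes yields N + 1 transformer
    # classes; the extra class is the mask token used by the discrete diffusion process.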

    height = original_transformer_config["content_spatial_size"][0]
    width = original_transformer_config["content_spatial_size"][1]

    assert width == height, "width has to be equal to height"
    dropout = original_transformer_config["resid_pdrop"]
    num_embeds_ada_norm = original_diffusion_config["diffusion_step"]

    model_kwargs = {
        "attention_bias": True,
        "cross_attention_dim": context_dim,
        "attention_head_dim": d_head,
        "num_layers": depth,
        "dropout": dropout,
        "num_attention_heads": n_heads,
        "num_vector_embeds": num_embed,
        "num_embeds_ada_norm": num_embeds_ada_norm,
        "norm_num_groups": 32,
        "sample_size": width,
        "activation_fn": "geglu-approximate",
    }

    model = Transformer2DModel(**model_kwargs)
    return model


# done transformer model

# transformer checkpoint


def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    transformer_prefix = "transformer.transformer"

    diffusers_latent_image_embedding_prefix = "latent_image_embedding"
    latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"

    # DalleMaskImageEmbedding
    diffusers_checkpoint.update(
        {
            f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.height_emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.width_emb.weight"
            ],
        }
    )

    # transformer blocks
    for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
        diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
        transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
        attention_prefix = f"{transformer_block_prefix}.attn1"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
        attention_prefix = f"{transformer_block_prefix}.attn2"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # norm block
        diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
        norm_block_prefix = f"{transformer_block_prefix}.ln2"

        diffusers_checkpoint.update(
            {
                f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
                f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
            }
        )

        # feedforward block
        diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
        feedforward_prefix = f"{transformer_block_prefix}.mlp"

        diffusers_checkpoint.update(
            transformer_feedforward_to_diffusers_checkpoint(
                checkpoint,
                diffusers_feedforward_prefix=diffusers_feedforward_prefix,
                feedforward_prefix=feedforward_prefix,
            )
        )

    # to logits

    diffusers_norm_out_prefix = "norm_out"
    norm_out_prefix = f"{transformer_prefix}.to_logits.0"

    diffusers_checkpoint.update(
        {
            f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
            f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
        }
    )

    diffusers_out_prefix = "out"
    out_prefix = f"{transformer_prefix}.to_logits.1"

    diffusers_checkpoint.update(
        {
            f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
            f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
        }
    )

    return diffusers_checkpoint


def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
    return {
        f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
        f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
        f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
    }


def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    return {
        # key
        f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
        f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
        # query
        f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
        f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
        # value
        f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
        f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
        # linear out
        f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
        f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
    }


def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
    return {
        f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
        f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
        f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
        f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
    }


# done transformer checkpoint


def read_config_file(filename):
    # The yaml file contains annotations that certain values should be loaded as tuples.
    # By default, OmegaConf will error out when reading these. Instead, we manually read the
    # yaml with the FullLoader and then construct the OmegaConf object from the result.
    with open(filename) as f:
        original_config = yaml.load(f, FullLoader)

    return OmegaConf.create(original_config)


# We take separate arguments for the vqvae because the ITHQ vqvae config file
# is separate from the config file for the rest of the model.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--vqvae_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the vqvae checkpoint to convert.",
    )

    parser.add_argument(
        "--vqvae_original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture for the vqvae.",
    )

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    # See the link below for how the EMA weights are always selected in the original VQ-Diffusion:
    # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
    parser.add_argument(
        "--no_use_ema",
        action="store_true",
        required=False,
        help=(
            "Set to not use the EMA weights from the original VQ-Diffusion checkpoint. You probably do not want to"
            " set it, as the original VQ-Diffusion always uses the EMA weights when loading models."
        ),
    )

    args = parser.parse_args()

    use_ema = not args.no_use_ema

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    # vqvae_model

    print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")

    vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
    vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]

    with init_empty_weights():
        vqvae_model = vqvae_model_from_original_config(vqvae_original_config)

    vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)

    with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
        torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
        del vqvae_diffusers_checkpoint
        del vqvae_checkpoint
        load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")

    print("done loading vqvae")

    # done vqvae_model

    # transformer_model

    print(
        f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
        f" {use_ema}"
    )

    original_config = read_config_file(args.original_config_file).model

    diffusion_config = original_config.params.diffusion_config
    transformer_config = original_config.params.diffusion_config.params.transformer_config
    content_embedding_config = original_config.params.diffusion_config.params.content_emb_config

    pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)

    if use_ema:
        if "ema" in pre_checkpoint:
            checkpoint = {}
            for k, v in pre_checkpoint["model"].items():
                checkpoint[k] = v

            for k, v in pre_checkpoint["ema"].items():
                # The EMA weights are only used on the transformer. To make their keys look as if they came
                # from the state_dict of the top level model, we prefix them with an additional "transformer."
                # See the source linked above the `--no_use_ema` argument for more information.
                checkpoint[f"transformer.{k}"] = v
        else:
            print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
            checkpoint = pre_checkpoint["model"]
    else:
        checkpoint = pre_checkpoint["model"]

    del pre_checkpoint

    with init_empty_weights():
        transformer_model = transformer_model_from_original_config(
            diffusion_config, transformer_config, content_embedding_config
        )

    diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
        transformer_model, checkpoint
    )

    # classifier free sampling embeddings interlude

    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
    # model, so we pull them off the checkpoint before the checkpoint is deleted.

    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

    if learnable_classifier_free_sampling_embeddings:
        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
    else:
        learned_classifier_free_sampling_embeddings_embeddings = None

    # done classifier free sampling embeddings interlude

    with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
        torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
        del diffusers_transformer_checkpoint
        del checkpoint
        load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")

    print("done loading transformer")

    # done transformer_model

    # text encoder

    print("loading CLIP text encoder")

    clip_name = "openai/clip-vit-base-patch32"

    # The original VQ-Diffusion specifies the pad value by the integer id used in the
    # returned tokens, and each model uses `0` as the pad value. The transformers CLIP API
    # instead specifies the pad value via the pad token itself, i.e. before tokenization.
    # Using the `!` pad token is equivalent to padding with the `0` pad value, since `!`
    # maps to token id 0 (checked by the assert below).
    pad_token = "!"

    tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")

    assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0

    text_encoder_model = CLIPTextModel.from_pretrained(
        clip_name,
        # `CLIPTextModel` does not support device_map="auto"
        # device_map="auto"
    )

    print("done loading CLIP text encoder")

    # done text encoder

    # scheduler

    scheduler_model = VQDiffusionScheduler(
        # the scheduler has the same number of embeddings as the transformer
        num_vec_classes=transformer_model.num_vector_embeds
    )

    # done scheduler

    # learned classifier free sampling embeddings

    with init_empty_weights():
        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
            learnable_classifier_free_sampling_embeddings,
            hidden_size=text_encoder_model.config.hidden_size,
            length=tokenizer_model.model_max_length,
        )

    learned_classifier_free_sampling_checkpoint = {
        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
    }

    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
        del learned_classifier_free_sampling_checkpoint
        del learned_classifier_free_sampling_embeddings_embeddings
        load_checkpoint_and_dispatch(
            learned_classifier_free_sampling_embeddings_model,
            learned_classifier_free_sampling_checkpoint_file.name,
            device_map="auto",
        )

    # done learned classifier free sampling embeddings

    print(f"saving VQ diffusion model, path: {args.dump_path}")

    pipe = VQDiffusionPipeline(
        vqvae=vqvae_model,
        transformer=transformer_model,
        tokenizer=tokenizer_model,
        text_encoder=text_encoder_model,
        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
        scheduler=scheduler_model,
    )
    pipe.save_pretrained(args.dump_path)

    print("done writing VQ diffusion model")