""" This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers. It currently only supports porting the ITHQ dataset. ITHQ dataset: ```sh # From the root directory of diffusers. # Download the VQVAE checkpoint $ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth # Download the VQVAE config # NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class # `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE` # loads `OUTPUT/pretrained_model/taming_dvae/config.yaml` $ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml # Download the main model checkpoint $ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth # Download the main model config $ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml # run the convert script $ python ./scripts/convert_vq_diffusion_to_diffusers.py \ --checkpoint_path ./ithq_learnable.pth \ --original_config_file ./ithq.yaml \ --vqvae_checkpoint_path ./ithq_vqvae.pth \ --vqvae_original_config_file ./ithq_vqvae.yaml \ --dump_path ``` """ import argparse import tempfile import torch import yaml from accelerate import init_empty_weights, load_checkpoint_and_dispatch from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings from transformers import CLIPTextModel, CLIPTokenizer from yaml.loader import FullLoader try: from omegaconf import OmegaConf except ImportError: raise ImportError( "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install" " OmegaConf`." ) # vqvae model PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] def vqvae_model_from_original_config(original_config): assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers." original_config = original_config.params original_encoder_config = original_config.encoder_config.params original_decoder_config = original_config.decoder_config.params in_channels = original_encoder_config.in_channels out_channels = original_decoder_config.out_ch down_block_types = get_down_block_types(original_encoder_config) up_block_types = get_up_block_types(original_decoder_config) assert original_encoder_config.ch == original_decoder_config.ch assert original_encoder_config.ch_mult == original_decoder_config.ch_mult block_out_channels = tuple( [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult] ) assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks layers_per_block = original_encoder_config.num_res_blocks assert original_encoder_config.z_channels == original_decoder_config.z_channels latent_channels = original_encoder_config.z_channels num_vq_embeddings = original_config.n_embed # Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion norm_num_groups = 32 e_dim = original_config.embed_dim model = VQModel( in_channels=in_channels, out_channels=out_channels, down_block_types=down_block_types, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, latent_channels=latent_channels, num_vq_embeddings=num_vq_embeddings, norm_num_groups=norm_num_groups, vq_embed_dim=e_dim, ) return model def get_down_block_types(original_encoder_config): attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions) num_resolutions = len(original_encoder_config.ch_mult) resolution = coerce_resolution(original_encoder_config.resolution) curr_res = resolution down_block_types = [] for _ in range(num_resolutions): if curr_res in attn_resolutions: down_block_type = "AttnDownEncoderBlock2D" else: down_block_type = "DownEncoderBlock2D" down_block_types.append(down_block_type) curr_res = [r // 2 for r in curr_res] return down_block_types def get_up_block_types(original_decoder_config): attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions) num_resolutions = len(original_decoder_config.ch_mult) resolution = coerce_resolution(original_decoder_config.resolution) curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] up_block_types = [] for _ in reversed(range(num_resolutions)): if curr_res in attn_resolutions: up_block_type = "AttnUpDecoderBlock2D" else: up_block_type = "UpDecoderBlock2D" up_block_types.append(up_block_type) curr_res = [r * 2 for r in curr_res] return up_block_types def coerce_attn_resolutions(attn_resolutions): attn_resolutions = OmegaConf.to_object(attn_resolutions) attn_resolutions_ = [] for ar in attn_resolutions: if isinstance(ar, (list, tuple)): attn_resolutions_.append(list(ar)) else: attn_resolutions_.append([ar, ar]) return attn_resolutions_ def coerce_resolution(resolution): resolution = OmegaConf.to_object(resolution) if isinstance(resolution, int): resolution = [resolution, resolution] # H, W elif isinstance(resolution, (tuple, list)): resolution = list(resolution) else: raise ValueError("Unknown type of resolution:", resolution) return resolution # done vqvae model # vqvae checkpoint def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint = {} diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) # quant_conv diffusers_checkpoint.update( { "quant_conv.weight": checkpoint["quant_conv.weight"], "quant_conv.bias": checkpoint["quant_conv.bias"], } ) # quantize diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]}) # post_quant_conv diffusers_checkpoint.update( { "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], } ) # decoder diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) return diffusers_checkpoint def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint = {} # conv_in diffusers_checkpoint.update( { "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], } ) # down_blocks for down_block_idx, down_block in enumerate(model.encoder.down_blocks): diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" down_block_prefix = f"encoder.down.{down_block_idx}" # resnets for resnet_idx, resnet in enumerate(down_block.resnets): diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" diffusers_checkpoint.update( vqvae_resnet_to_diffusers_checkpoint( resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix ) ) # downsample # do not include the downsample when on the last down block # There is no downsample on the last down block if down_block_idx != len(model.encoder.down_blocks) - 1: # There's a single downsample in the original checkpoint but a list of downsamples # in the diffusers model. diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" downsample_prefix = f"{down_block_prefix}.downsample.conv" diffusers_checkpoint.update( { f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], } ) # attentions if hasattr(down_block, "attentions"): for attention_idx, _ in enumerate(down_block.attentions): diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" diffusers_checkpoint.update( vqvae_attention_to_diffusers_checkpoint( checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix, ) ) # mid block # mid block attentions # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder diffusers_attention_prefix = "encoder.mid_block.attentions.0" attention_prefix = "encoder.mid.attn_1" diffusers_checkpoint.update( vqvae_attention_to_diffusers_checkpoint( checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix ) ) # mid block resnets for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" # the hardcoded prefixes to `block_` are 1 and 2 orig_resnet_idx = diffusers_resnet_idx + 1 # There are two hardcoded resnets in the middle of the VQ-diffusion encoder resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" diffusers_checkpoint.update( vqvae_resnet_to_diffusers_checkpoint( resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix ) ) diffusers_checkpoint.update( { # conv_norm_out "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], # conv_out "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], } ) return diffusers_checkpoint def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint = {} # conv in diffusers_checkpoint.update( { "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], } ) # up_blocks for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): # up_blocks are stored in reverse order in the VQ-diffusion checkpoint orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" up_block_prefix = f"decoder.up.{orig_up_block_idx}" # resnets for resnet_idx, resnet in enumerate(up_block.resnets): diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" diffusers_checkpoint.update( vqvae_resnet_to_diffusers_checkpoint( resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix ) ) # upsample # there is no up sample on the last up block if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples # in the diffusers model. diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" downsample_prefix = f"{up_block_prefix}.upsample.conv" diffusers_checkpoint.update( { f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], } ) # attentions if hasattr(up_block, "attentions"): for attention_idx, _ in enumerate(up_block.attentions): diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" diffusers_checkpoint.update( vqvae_attention_to_diffusers_checkpoint( checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix, ) ) # mid block # mid block attentions # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder diffusers_attention_prefix = "decoder.mid_block.attentions.0" attention_prefix = "decoder.mid.attn_1" diffusers_checkpoint.update( vqvae_attention_to_diffusers_checkpoint( checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix ) ) # mid block resnets for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" # the hardcoded prefixes to `block_` are 1 and 2 orig_resnet_idx = diffusers_resnet_idx + 1 # There are two hardcoded resnets in the middle of the VQ-diffusion decoder resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" diffusers_checkpoint.update( vqvae_resnet_to_diffusers_checkpoint( resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix ) ) diffusers_checkpoint.update( { # conv_norm_out "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"], "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"], # conv_out "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], } ) return diffusers_checkpoint def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): rv = { # norm1 f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], # conv1 f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], # norm2 f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], # conv2 f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], } if resnet.conv_shortcut is not None: rv.update( { f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], } ) return rv def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): return { # group_norm f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], # query f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], # key f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], # value f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], # proj_attn f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ :, :, 0, 0 ], f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], } # done vqvae checkpoint # transformer model PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"] PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"] PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"] def transformer_model_from_original_config( original_diffusion_config, original_transformer_config, original_content_embedding_config ): assert ( original_diffusion_config.target in PORTED_DIFFUSIONS ), f"{original_diffusion_config.target} has not yet been ported to diffusers." assert ( original_transformer_config.target in PORTED_TRANSFORMERS ), f"{original_transformer_config.target} has not yet been ported to diffusers." assert ( original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS ), f"{original_content_embedding_config.target} has not yet been ported to diffusers." original_diffusion_config = original_diffusion_config.params original_transformer_config = original_transformer_config.params original_content_embedding_config = original_content_embedding_config.params inner_dim = original_transformer_config["n_embd"] n_heads = original_transformer_config["n_head"] # VQ-Diffusion gives dimension of the multi-headed attention layers as the # number of attention heads times the sequence length (the dimension) of a # single head. We want to specify our attention blocks with those values # specified separately assert inner_dim % n_heads == 0 d_head = inner_dim // n_heads depth = original_transformer_config["n_layer"] context_dim = original_transformer_config["condition_dim"] num_embed = original_content_embedding_config["num_embed"] # the number of embeddings in the transformer includes the mask embedding. # the content embedding (the vqvae) does not include the mask embedding. num_embed = num_embed + 1 height = original_transformer_config["content_spatial_size"][0] width = original_transformer_config["content_spatial_size"][1] assert width == height, "width has to be equal to height" dropout = original_transformer_config["resid_pdrop"] num_embeds_ada_norm = original_diffusion_config["diffusion_step"] model_kwargs = { "attention_bias": True, "cross_attention_dim": context_dim, "attention_head_dim": d_head, "num_layers": depth, "dropout": dropout, "num_attention_heads": n_heads, "num_vector_embeds": num_embed, "num_embeds_ada_norm": num_embeds_ada_norm, "norm_num_groups": 32, "sample_size": width, "activation_fn": "geglu-approximate", } model = Transformer2DModel(**model_kwargs) return model # done transformer model # transformer checkpoint def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): diffusers_checkpoint = {} transformer_prefix = "transformer.transformer" diffusers_latent_image_embedding_prefix = "latent_image_embedding" latent_image_embedding_prefix = f"{transformer_prefix}.content_emb" # DalleMaskImageEmbedding diffusers_checkpoint.update( { f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[ f"{latent_image_embedding_prefix}.emb.weight" ], f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[ f"{latent_image_embedding_prefix}.height_emb.weight" ], f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[ f"{latent_image_embedding_prefix}.width_emb.weight" ], } ) # transformer blocks for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks): diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}" transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}" # ada norm block diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1" ada_norm_prefix = f"{transformer_block_prefix}.ln1" diffusers_checkpoint.update( transformer_ada_norm_to_diffusers_checkpoint( checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix ) ) # attention block diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1" attention_prefix = f"{transformer_block_prefix}.attn1" diffusers_checkpoint.update( transformer_attention_to_diffusers_checkpoint( checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix ) ) # ada norm block diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2" ada_norm_prefix = f"{transformer_block_prefix}.ln1_1" diffusers_checkpoint.update( transformer_ada_norm_to_diffusers_checkpoint( checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix ) ) # attention block diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2" attention_prefix = f"{transformer_block_prefix}.attn2" diffusers_checkpoint.update( transformer_attention_to_diffusers_checkpoint( checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix ) ) # norm block diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3" norm_block_prefix = f"{transformer_block_prefix}.ln2" diffusers_checkpoint.update( { f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"], f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"], } ) # feedforward block diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff" feedforward_prefix = f"{transformer_block_prefix}.mlp" diffusers_checkpoint.update( transformer_feedforward_to_diffusers_checkpoint( checkpoint, diffusers_feedforward_prefix=diffusers_feedforward_prefix, feedforward_prefix=feedforward_prefix, ) ) # to logits diffusers_norm_out_prefix = "norm_out" norm_out_prefix = f"{transformer_prefix}.to_logits.0" diffusers_checkpoint.update( { f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"], f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"], } ) diffusers_out_prefix = "out" out_prefix = f"{transformer_prefix}.to_logits.1" diffusers_checkpoint.update( { f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"], f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"], } ) return diffusers_checkpoint def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix): return { f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"], f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"], f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"], } def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): return { # key f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"], f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"], # query f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"], f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"], # value f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"], f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"], # linear out f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"], f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"], } def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix): return { f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"], f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"], f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"], f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"], } # done transformer checkpoint def read_config_file(filename): # The yaml file contains annotations that certain values should # loaded as tuples. By default, OmegaConf will panic when reading # these. Instead, we can manually read the yaml with the FullLoader and then # construct the OmegaConf object. with open(filename) as f: original_config = yaml.load(f, FullLoader) return OmegaConf.create(original_config) # We take separate arguments for the vqvae because the ITHQ vqvae config file # is separate from the config file for the rest of the model. if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--vqvae_checkpoint_path", default=None, type=str, required=True, help="Path to the vqvae checkpoint to convert.", ) parser.add_argument( "--vqvae_original_config_file", default=None, type=str, required=True, help="The YAML config file corresponding to the original architecture for the vqvae.", ) parser.add_argument( "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." ) parser.add_argument( "--original_config_file", default=None, type=str, required=True, help="The YAML config file corresponding to the original architecture.", ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument( "--checkpoint_load_device", default="cpu", type=str, required=False, help="The device passed to `map_location` when loading checkpoints.", ) # See link for how ema weights are always selected # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65 parser.add_argument( "--no_use_ema", action="store_true", required=False, help=( "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set" " it as the original VQ-Diffusion always uses the ema weights when loading models." ), ) args = parser.parse_args() use_ema = not args.no_use_ema print(f"loading checkpoints to {args.checkpoint_load_device}") checkpoint_map_location = torch.device(args.checkpoint_load_device) # vqvae_model print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}") vqvae_original_config = read_config_file(args.vqvae_original_config_file).model vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"] with init_empty_weights(): vqvae_model = vqvae_model_from_original_config(vqvae_original_config) vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint) with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file: torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name) del vqvae_diffusers_checkpoint del vqvae_checkpoint load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto") print("done loading vqvae") # done vqvae_model # transformer_model print( f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:" f" {use_ema}" ) original_config = read_config_file(args.original_config_file).model diffusion_config = original_config.params.diffusion_config transformer_config = original_config.params.diffusion_config.params.transformer_config content_embedding_config = original_config.params.diffusion_config.params.content_emb_config pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) if use_ema: if "ema" in pre_checkpoint: checkpoint = {} for k, v in pre_checkpoint["model"].items(): checkpoint[k] = v for k, v in pre_checkpoint["ema"].items(): # The ema weights are only used on the transformer. To mimic their key as if they came # from the state_dict for the top level model, we prefix with an additional "transformer." # See the source linked in the args.use_ema config for more information. checkpoint[f"transformer.{k}"] = v else: print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.") checkpoint = pre_checkpoint["model"] else: checkpoint = pre_checkpoint["model"] del pre_checkpoint with init_empty_weights(): transformer_model = transformer_model_from_original_config( diffusion_config, transformer_config, content_embedding_config ) diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint( transformer_model, checkpoint ) # classifier free sampling embeddings interlude # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate # model, so we pull them off the checkpoint before the checkpoint is deleted. learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf if learnable_classifier_free_sampling_embeddings: learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] else: learned_classifier_free_sampling_embeddings_embeddings = None # done classifier free sampling embeddings interlude with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) del diffusers_transformer_checkpoint del checkpoint load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto") print("done loading transformer") # done transformer_model # text encoder print("loading CLIP text encoder") clip_name = "openai/clip-vit-base-patch32" # The original VQ-Diffusion specifies the pad value by the int used in the # returned tokens. Each model uses `0` as the pad value. The transformers clip api # specifies the pad value via the token before it has been tokenized. The `!` pad # token is the same as padding with the `0` pad value. pad_token = "!" tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 text_encoder_model = CLIPTextModel.from_pretrained( clip_name, # `CLIPTextModel` does not support device_map="auto" # device_map="auto" ) print("done loading CLIP text encoder") # done text encoder # scheduler scheduler_model = VQDiffusionScheduler( # the scheduler has the same number of embeddings as the transformer num_vec_classes=transformer_model.num_vector_embeds ) # done scheduler # learned classifier free sampling embeddings with init_empty_weights(): learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings( learnable_classifier_free_sampling_embeddings, hidden_size=text_encoder_model.config.hidden_size, length=tokenizer_model.model_max_length, ) learned_classifier_free_sampling_checkpoint = { "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float() } with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file: torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name) del learned_classifier_free_sampling_checkpoint del learned_classifier_free_sampling_embeddings_embeddings load_checkpoint_and_dispatch( learned_classifier_free_sampling_embeddings_model, learned_classifier_free_sampling_checkpoint_file.name, device_map="auto", ) # done learned classifier free sampling embeddings print(f"saving VQ diffusion model, path: {args.dump_path}") pipe = VQDiffusionPipeline( vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model, learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model, scheduler=scheduler_model, ) pipe.save_pretrained(args.dump_path) print("done writing VQ diffusion model")