From 89793a97e2c65016c02382f51e786ff52aff9be9 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 25 Aug 2022 15:46:09 +0200 Subject: [PATCH] Style the `scripts` directory (#250) Style scripts --- Makefile | 2 +- .../change_naming_configs_and_checkpoints.py | 11 +- scripts/conversion_ldm_uncond.py | 10 +- ...t_ddpm_original_checkpoint_to_diffusers.py | 334 +++++++++++------- ...rt_ldm_original_checkpoint_to_diffusers.py | 220 +++++++----- ...ncsnpp_original_checkpoint_to_diffusers.py | 4 +- scripts/generate_logits.py | 176 +++++---- 7 files changed, 447 insertions(+), 310 deletions(-) diff --git a/Makefile b/Makefile index fa346868..6e513e2e 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src -check_dirs := examples tests src utils +check_dirs := examples scripts src tests utils modified_only_fixup: $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) diff --git a/scripts/change_naming_configs_and_checkpoints.py b/scripts/change_naming_configs_and_checkpoints.py index 20e1d5c7..756bdccc 100644 --- a/scripts/change_naming_configs_and_checkpoints.py +++ b/scripts/change_naming_configs_and_checkpoints.py @@ -15,12 +15,15 @@ """ Conversion script for the LDM checkpoints. """ import argparse -import os import json +import os + import torch -from diffusers import UNet2DModel, UNet2DConditionModel + +from diffusers import UNet2DConditionModel, UNet2DModel from transformers.file_utils import has_file + do_only_config = False do_only_weights = True do_only_renaming = False @@ -37,9 +40,7 @@ if __name__ == "__main__": help="The config json file corresponding to the architecture.", ) - parser.add_argument( - "--dump_path", default=None, type=str, required=True, help="Path to the output model." - ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") args = parser.parse_args() diff --git a/scripts/conversion_ldm_uncond.py b/scripts/conversion_ldm_uncond.py index 0957c2ed..67edd638 100644 --- a/scripts/conversion_ldm_uncond.py +++ b/scripts/conversion_ldm_uncond.py @@ -1,9 +1,10 @@ import argparse -import OmegaConf import torch -from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler +import OmegaConf +from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel + def convert_ldm_original(checkpoint_path, config_path, output_path): config = OmegaConf.load(config_path) @@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path): for key in keys: if key.startswith(first_stage_key): first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key] - + # extract state_dict for UNetLDM unet_state_dict = {} unet_key = "model.diffusion_model." 
for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = state_dict[key] - + vqvae_init_args = config.model.params.first_stage_config.params unet_init_args = config.model.params.unet_config.params @@ -53,4 +54,3 @@ if __name__ == "__main__": args = parser.parse_args() convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path) - diff --git a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py index 88cd92d8..52d75c75 100644 --- a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -1,31 +1,33 @@ -from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL import argparse import json + import torch +from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel + def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. """ if n_shave_prefix_segments >= 0: - return '.'.join(path.split('.')[n_shave_prefix_segments:]) + return ".".join(path.split(".")[n_shave_prefix_segments:]) else: - return '.'.join(path.split('.')[:n_shave_prefix_segments]) + return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): mapping = [] for old_item in old_list: new_item = old_item - new_item = new_item.replace('block.', 'resnets.') - new_item = new_item.replace('conv_shorcut', 'conv1') - new_item = new_item.replace('nin_shortcut', 'conv_shortcut') - new_item = new_item.replace('temb_proj', 'time_emb_proj') + new_item = new_item.replace("block.", "resnets.") + new_item = new_item.replace("conv_shorcut", "conv1") + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = new_item.replace("temb_proj", "time_emb_proj") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) return mapping @@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False): # In `model.mid`, the layer is called `attn`. if not in_mid: - new_item = new_item.replace('attn', 'attentions') - new_item = new_item.replace('.k.', '.key.') - new_item = new_item.replace('.v.', '.value.') - new_item = new_item.replace('.q.', '.query.') + new_item = new_item.replace("attn", "attentions") + new_item = new_item.replace(".k.", ".key.") + new_item = new_item.replace(".v.", ".value.") + new_item = new_item.replace(".q.", ".query.") - new_item = new_item.replace('proj_out', 'proj_attn') - new_item = new_item.replace('norm', 'group_norm') + new_item = new_item.replace("proj_out", "proj_attn") + new_item = new_item.replace("norm", "group_norm") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) return mapping -def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None): +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 
if attention_paths_to_split is not None: @@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map['query']] = query.reshape(target_shape).squeeze() - checkpoint[path_map['key']] = key.reshape(target_shape).squeeze() - checkpoint[path_map['value']] = value.reshape(target_shape).squeeze() + checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze() + checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze() + checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze() for path in paths: - new_path = path['new'] + new_path = path["new"] if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue - new_path = new_path.replace('down.', 'down_blocks.') - new_path = new_path.replace('up.', 'up_blocks.') + new_path = new_path.replace("down.", "down_blocks.") + new_path = new_path.replace("up.", "up_blocks.") if additional_replacements is not None: for replacement in additional_replacements: - new_path = new_path.replace(replacement['old'], replacement['new']) + new_path = new_path.replace(replacement["old"], replacement["new"]) - if 'attentions' in new_path: - checkpoint[new_path] = old_checkpoint[path['old']].squeeze() + if "attentions" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]].squeeze() else: - checkpoint[new_path] = old_checkpoint[path['old']] + checkpoint[new_path] = old_checkpoint[path["old"]] def convert_ddpm_checkpoint(checkpoint, config): @@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config): """ new_checkpoint = {} - new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight'] - new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias'] - new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight'] - new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias'] + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"] - new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight'] - new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias'] + new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"] - new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight'] - new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias'] - new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight'] - new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias'] + new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"] - num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer}) - down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)} + num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in 
checkpoint if "down" in layer}) + down_blocks = { + layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } - num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer}) - up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)} + num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer}) + up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} for i in range(num_down_blocks): - block_id = (i - 1) // (config['layers_per_block'] + 1) + block_id = (i - 1) // (config["layers_per_block"] + 1) - if any('downsample' in layer for layer in down_blocks[i]): - new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight'] - new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias'] -# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] -# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias'] + if any("downsample" in layer for layer in down_blocks[i]): + new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[ + f"down.{i}.downsample.op.weight" + ] + new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"] + # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] + # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias'] - if any('block' in layer for layer in down_blocks[i]): - num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer}) - blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("block" in layer for layer in down_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key] + for layer_id in range(num_blocks) + } if num_blocks > 0: - for j in range(config['layers_per_block']): + for j in range(config["layers_per_block"]): paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint) - if any('attn' in layer for layer in down_blocks[i]): - num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer}) - attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("attn" in layer for layer in down_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key] + for layer_id in range(num_blocks) + } if num_attn > 0: - for j in range(config['layers_per_block']): + for j in range(config["layers_per_block"]): paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) @@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config): # Mid new 2 paths = renew_resnet_paths(mid_block_1_layers) - assign_to_checkpoint(paths, new_checkpoint, checkpoint, 
additional_replacements=[ - {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'} - ]) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}], + ) paths = renew_resnet_paths(mid_block_2_layers) - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ - {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'} - ]) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}], + ) paths = renew_attention_paths(mid_attn_1_layers, in_mid=True) - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ - {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'} - ]) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}], + ) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i - if any('upsample' in layer for layer in up_blocks[i]): - new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight'] - new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias'] + if any("upsample" in layer for layer in up_blocks[i]): + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[ + f"up.{i}.upsample.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"] - if any('block' in layer for layer in up_blocks[i]): - num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer}) - blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("block" in layer for layer in up_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks) + } if num_blocks > 0: - for j in range(config['layers_per_block'] + 1): - replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - if any('attn' in layer for layer in up_blocks[i]): - num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer}) - attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("attn" in layer for layer in up_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks) + } if num_attn > 0: - for j in range(config['layers_per_block'] + 1): - replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": 
f"up_blocks.{block_id}"} paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()} + new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()} return new_checkpoint @@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config): """ new_checkpoint = {} - new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight'] - new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias'] + new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"] - new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight'] - new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias'] - new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight'] - new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias'] + new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"] - new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight'] - new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias'] + new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"] - new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight'] - new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias'] - new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight'] - new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias'] + new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"] - num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer}) - down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)} + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer}) + down_blocks = { + layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } - num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer}) - up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)} + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer}) + up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} for i in range(num_down_blocks): - block_id = (i - 1) // (config['layers_per_block'] + 1) + block_id = (i - 1) // (config["layers_per_block"] + 1) - if any('downsample' in layer for layer in 
down_blocks[i]): - new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight'] - new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias'] + if any("downsample" in layer for layer in down_blocks[i]): + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[ + f"encoder.down.{i}.downsample.conv.weight" + ] + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[ + f"encoder.down.{i}.downsample.conv.bias" + ] - if any('block' in layer for layer in down_blocks[i]): - num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer}) - blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("block" in layer for layer in down_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key] + for layer_id in range(num_blocks) + } if num_blocks > 0: - for j in range(config['layers_per_block']): + for j in range(config["layers_per_block"]): paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint) - if any('attn' in layer for layer in down_blocks[i]): - num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer}) - attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("attn" in layer for layer in down_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key] + for layer_id in range(num_blocks) + } if num_attn > 0: - for j in range(config['layers_per_block']): + for j in range(config["layers_per_block"]): paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) @@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config): # Mid new 2 paths = renew_resnet_paths(mid_block_1_layers) - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ - {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'} - ]) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}], + ) paths = renew_resnet_paths(mid_block_2_layers) - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ - {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'} - ]) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}], + ) paths = renew_attention_paths(mid_attn_1_layers, in_mid=True) - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[ - {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'} - ]) + assign_to_checkpoint( + paths, + new_checkpoint, + checkpoint, + additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}], + ) for i in range(num_up_blocks): block_id = num_up_blocks - 
1 - i - if any('upsample' in layer for layer in up_blocks[i]): - new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight'] - new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias'] + if any("upsample" in layer for layer in up_blocks[i]): + new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[ + f"decoder.up.{i}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[ + f"decoder.up.{i}.upsample.conv.bias" + ] - if any('block' in layer for layer in up_blocks[i]): - num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer}) - blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("block" in layer for layer in up_blocks[i]): + num_blocks = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer} + ) + blocks = { + layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks) + } if num_blocks > 0: - for j in range(config['layers_per_block'] + 1): - replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - if any('attn' in layer for layer in up_blocks[i]): - num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer}) - attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + if any("attn" in layer for layer in up_blocks[i]): + num_attn = len( + {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer} + ) + attns = { + layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks) + } if num_attn > 0: - for j in range(config['layers_per_block'] + 1): - replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} + for j in range(config["layers_per_block"] + 1): + replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"} paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()} + new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()} new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"] new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"] if "quantize.embedding.weight" in checkpoint: @@ -321,9 +395,7 @@ if __name__ == "__main__": help="The config json file corresponding to the architecture.", ) - parser.add_argument( - "--dump_path", default=None, type=str, required=True, help="Path to the output model." 
- ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") args = parser.parse_args() checkpoint = torch.load(args.checkpoint_path) diff --git a/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/scripts/convert_ldm_original_checkpoint_to_diffusers.py index f4f8331e..52865792 100644 --- a/scripts/convert_ldm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ldm_original_checkpoint_to_diffusers.py @@ -16,8 +16,10 @@ import argparse import json + import torch -from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline + +from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel def shave_segments(path, n_shave_prefix_segments=1): @@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1): Removes segments. Positive values shave the first segments, negative shave the last segments. """ if n_shave_prefix_segments >= 0: - return '.'.join(path.split('.')[n_shave_prefix_segments:]) + return ".".join(path.split(".")[n_shave_prefix_segments:]) else: - return '.'.join(path.split('.')[:n_shave_prefix_segments]) + return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): @@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0): """ mapping = [] for old_item in old_list: - new_item = old_item.replace('in_layers.0', 'norm1') - new_item = new_item.replace('in_layers.2', 'conv1') + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") - new_item = new_item.replace('out_layers.0', 'norm2') - new_item = new_item.replace('out_layers.3', 'conv2') + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") - new_item = new_item.replace('emb_layers.1', 'time_emb_proj') - new_item = new_item.replace('skip_connection', 'conv_shortcut') + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) return mapping @@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0): for old_item in old_list: new_item = old_item - new_item = new_item.replace('norm.weight', 'group_norm.weight') - new_item = new_item.replace('norm.bias', 'group_norm.bias') + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) return mapping -def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None): +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. 
It splits attention layers, and takes into account additional replacements @@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map['query']] = query.reshape(target_shape) - checkpoint[path_map['key']] = key.reshape(target_shape) - checkpoint[path_map['value']] = value.reshape(target_shape) + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) for path in paths: - new_path = path['new'] + new_path = path["new"] # These have already been assigned if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue # Global renaming happens here - new_path = new_path.replace('middle_block.0', 'mid.resnets.0') - new_path = new_path.replace('middle_block.1', 'mid.attentions.0') - new_path = new_path.replace('middle_block.2', 'mid.resnets.1') + new_path = new_path.replace("middle_block.0", "mid.resnets.0") + new_path = new_path.replace("middle_block.1", "mid.attentions.0") + new_path = new_path.replace("middle_block.2", "mid.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: - new_path = new_path.replace(replacement['old'], replacement['new']) + new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0] + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] else: - checkpoint[new_path] = old_checkpoint[path['old']] + checkpoint[new_path] = old_checkpoint[path["old"]] def convert_ldm_checkpoint(checkpoint, config): @@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config): """ new_checkpoint = {} - new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight'] - new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias'] - new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight'] - new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias'] + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] - new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight'] - new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias'] + new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] - new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight'] - new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias'] - new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight'] - new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias'] + new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"] + new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"] + new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"] # Retrieves the keys for the input 
blocks only - num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer}) - input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)} + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } # Retrieves the keys for the middle blocks only - num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer}) - middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)} + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } # Retrieves the keys for the output blocks only - num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer}) - output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)} + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } for i in range(1, num_input_blocks): - block_id = (i - 1) // (config['num_res_blocks'] + 1) - layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1) + block_id = (i - 1) // (config["num_res_blocks"] + 1) + layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1) - resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key] - attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key] + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - if f'input_blocks.{i}.0.op.weight' in checkpoint: - new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.weight'] = checkpoint[f'input_blocks.{i}.0.op.weight'] - new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.bias'] = checkpoint[f'input_blocks.{i}.0.op.bias'] + if f"input_blocks.{i}.0.op.weight" in checkpoint: + new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[ + f"input_blocks.{i}.0.op.weight" + ] + new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[ + f"input_blocks.{i}.0.op.bias" + ] paths = renew_resnet_paths(resnets) - meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'} - resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'} - assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"} + resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"} + assign_to_checkpoint( + paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config + ) if len(attentions): paths = renew_attention_paths(attentions) - meta_path = {'old': f'input_blocks.{i}.1', 'new': 
f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}'} + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}", + } to_split = { - f'input_blocks.{i}.1.qkv.bias': { - 'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias', - 'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias', - 'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias', + f"input_blocks.{i}.1.qkv.bias": { + "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", + "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", + "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", }, - f'input_blocks.{i}.1.qkv.weight': { - 'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight', - 'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight', - 'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight', + f"input_blocks.{i}.1.qkv.weight": { + "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", + "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", + "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", }, } assign_to_checkpoint( @@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config): checkpoint, additional_replacements=[meta_path], attention_paths_to_split=to_split, - config=config + config=config, ) resnet_0 = middle_blocks[0] @@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config): attentions_paths = renew_attention_paths(attentions) to_split = { - 'middle_block.1.qkv.bias': { - 'key': 'mid_block.attentions.0.key.bias', - 'query': 'mid_block.attentions.0.query.bias', - 'value': 'mid_block.attentions.0.value.bias', + "middle_block.1.qkv.bias": { + "key": "mid_block.attentions.0.key.bias", + "query": "mid_block.attentions.0.query.bias", + "value": "mid_block.attentions.0.value.bias", }, - 'middle_block.1.qkv.weight': { - 'key': 'mid_block.attentions.0.key.weight', - 'query': 'mid_block.attentions.0.query.weight', - 'value': 'mid_block.attentions.0.value.weight', + "middle_block.1.qkv.weight": { + "key": "mid_block.attentions.0.key.weight", + "query": "mid_block.attentions.0.query.weight", + "value": "mid_block.attentions.0.value.weight", }, } - assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config) + assign_to_checkpoint( + attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config + ) for i in range(num_output_blocks): - block_id = i // (config['num_res_blocks'] + 1) - layer_in_block_id = i % (config['num_res_blocks'] + 1) + block_id = i // (config["num_res_blocks"] + 1) + layer_in_block_id = i % (config["num_res_blocks"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: - layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1) + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key] - attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in 
key] + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'} + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config) - if ['conv.weight', 'conv.bias'] in output_block_list.values(): - index = list(output_block_list.values()).index(['conv.weight', 'conv.bias']) - new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight'] - new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias'] + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[ + f"output_blocks.{i}.{index}.conv.bias" + ] # Clear attentions as they have been attributed above. if len(attentions) == 2: @@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config): if len(attentions): paths = renew_attention_paths(attentions) meta_path = { - 'old': f'output_blocks.{i}.1', - 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}' + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } to_split = { - f'output_blocks.{i}.1.qkv.bias': { - 'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias', - 'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias', - 'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias', + f"output_blocks.{i}.1.qkv.bias": { + "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", + "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", + "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", }, - f'output_blocks.{i}.1.qkv.weight': { - 'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight', - 'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight', - 'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight', + f"output_blocks.{i}.1.qkv.weight": { + "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", + "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", + "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", }, } assign_to_checkpoint( @@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config): new_checkpoint, checkpoint, additional_replacements=[meta_path], - attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None, + attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None, config=config, ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: - old_path = '.'.join(['output_blocks', str(i), path['old']]) - new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']]) + old_path = ".".join(["output_blocks", 
str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = checkpoint[old_path] @@ -303,9 +331,7 @@ if __name__ == "__main__": help="The config json file corresponding to the architecture.", ) - parser.add_argument( - "--dump_path", default=None, type=str, required=True, help="Path to the output model." - ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") args = parser.parse_args() diff --git a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py index 8f02d691..271359b8 100644 --- a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py @@ -16,8 +16,10 @@ import argparse import json + import torch -from diffusers import UNet2DModel + +from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel def convert_ncsnpp_checkpoint(checkpoint, config): diff --git a/scripts/generate_logits.py b/scripts/generate_logits.py index 4dbe30f7..61851212 100644 --- a/scripts/generate_logits.py +++ b/scripts/generate_logits.py @@ -1,91 +1,127 @@ -from huggingface_hub import HfApi -from transformers.file_utils import has_file -from diffusers import UNet2DModel import random + import torch + +from diffusers import UNet2DModel +from huggingface_hub import HfApi + + api = HfApi() results = {} -results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467, - 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189, - -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839, - 0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557]) -results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436, - 1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208, - -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948, - 2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365]) -results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869, - -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304, - -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925, - 0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943]) -results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172, - -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309, - 0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805, - -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505]) -results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133, - -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395, - 0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559, - -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386]) -results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078, - -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330, - 0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683, - -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431]) -results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042, - -0.0200, 0.1385, -0.0115, -0.0132, 
0.0017, -0.0965, -0.0802, 0.0398, - 0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574, - -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390]) -results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042, - -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290, - 0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746, - -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473]) -results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330, - 1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243, - -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810, - 1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251]) -results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324, - 0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181, - -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259, - 1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266]) -results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212, - 0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027, - -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131, - 1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355]) -results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959, - 1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351, - -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341, - 3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066]) -results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740, - 1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398, - -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395, - 2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243]) -results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336, - 1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908, - -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560, - 3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343]) -results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344, - 1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391, - -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439, - 1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219]) +# fmt: off +results["google_ddpm_cifar10_32"] = torch.tensor([ + -0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467, + 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189, + -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839, + 0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557 +]) +results["google_ddpm_ema_bedroom_256"] = torch.tensor([ + -2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436, + 1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208, + -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948, + 2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365 +]) +results["CompVis_ldm_celebahq_256"] = torch.tensor([ + -0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869, + -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304, + -0.1302, -0.2802, -0.2084, 
-0.2025, -0.4967, -0.4873, -0.0861, 0.6925, + 0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943 +]) +results["google_ncsnpp_ffhq_1024"] = torch.tensor([ + 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172, + -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309, + 0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805, + -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505 +]) +results["google_ncsnpp_bedroom_256"] = torch.tensor([ + 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133, + -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395, + 0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559, + -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386 +]) +results["google_ncsnpp_celebahq_256"] = torch.tensor([ + 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078, + -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330, + 0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683, + -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431 +]) +results["google_ncsnpp_church_256"] = torch.tensor([ + 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042, + -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398, + 0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574, + -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390 +]) +results["google_ncsnpp_ffhq_256"] = torch.tensor([ + 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042, + -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290, + 0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746, + -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473 +]) +results["google_ddpm_cat_256"] = torch.tensor([ + -1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330, + 1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243, + -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810, + 1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251]) +results["google_ddpm_celebahq_256"] = torch.tensor([ + -1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324, + 0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181, + -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259, + 1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266 +]) +results["google_ddpm_ema_celebahq_256"] = torch.tensor([ + -1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212, + 0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027, + -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131, + 1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355 +]) +results["google_ddpm_church_256"] = torch.tensor([ + -2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959, + 1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351, + -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341, + 3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066 +]) +results["google_ddpm_bedroom_256"] = torch.tensor([ + -2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740, + 1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398, + -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395, + 2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243 +]) +results["google_ddpm_ema_church_256"] = torch.tensor([ + -2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336, + 1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908, + -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 
3.0560, + 3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343 +]) +results["google_ddpm_ema_cat_256"] = torch.tensor([ + -1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344, + 1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391, + -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439, + 1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219 +]) +# fmt: on models = api.list_models(filter="diffusers") for mod in models: - if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256": + if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256": local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1] print(f"Started running {mod.modelId}!!!") if mod.modelId.startswith("CompVis"): - model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet") - else: + model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet") + else: model = UNet2DModel.from_pretrained(local_checkpoint) - + torch.manual_seed(0) random.seed(0) - + noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) time_step = torch.tensor([10] * noise.shape[0]) with torch.no_grad(): - logits = model(noise, time_step)['sample'] + logits = model(noise, time_step)["sample"] - assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3) + assert torch.allclose( + logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3 + ) print(f"{mod.modelId} has passed succesfully!!!")
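
Note: the conversion scripts touched by this patch all lean on the same key-renaming helpers (`shave_segments`, `renew_resnet_paths`, `renew_attention_paths`). Below is a minimal, illustrative sketch of how `shave_segments` behaves, using only the definition included in the patch above; the example checkpoint key names are hypothetical and chosen purely for demonstration.

    def shave_segments(path, n_shave_prefix_segments=1):
        """Removes segments. Positive values shave the first segments, negative shave the last segments."""
        if n_shave_prefix_segments >= 0:
            return ".".join(path.split(".")[n_shave_prefix_segments:])
        else:
            return ".".join(path.split(".")[:n_shave_prefix_segments])

    # Positive values drop leading segments, e.g. stripping a "down.0." prefix
    # from an (hypothetical) original checkpoint key:
    assert shave_segments("down.0.block.1.conv1.weight", 2) == "block.1.conv1.weight"

    # Negative values instead drop trailing segments and keep the prefix:
    assert shave_segments("down.0.block.1.conv1.weight", -1) == "down.0.block.1.conv1"

The `renew_*_paths` helpers then wrap this in a list of `{"old": ..., "new": ...}` mappings that `assign_to_checkpoint` applies, together with any `additional_replacements`, to build the converted state dict.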