Style the `scripts` directory (#250)

Style scripts
Anton Lozhkov 2022-08-25 15:46:09 +02:00 committed by GitHub
parent 365f75233f
commit 89793a97e2
7 changed files with 447 additions and 310 deletions

View File

@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
-check_dirs := examples tests src utils
+check_dirs := examples scripts src tests utils
 modified_only_fixup:
 $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
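Note on the hunk above: `modified_only_fixup` only formats Python files that actually changed, as reported by `utils/get_modified_files.py` for the directories in `check_dirs`; adding `scripts` there is what brings the conversion scripts under the Makefile's style checks. A rough, hypothetical sketch of what such a helper typically does (not the repository's actual implementation):

# Hypothetical sketch of a get_modified_files.py-style helper: print the .py files
# changed relative to main, restricted to the directories passed on the command line.
import subprocess
import sys


def modified_python_files(dirs):
    # "main..." diffs HEAD against the merge base with main, keeping the check local and fast.
    out = subprocess.check_output(["git", "diff", "--name-only", "main..."], text=True)
    changed = [f for f in out.splitlines() if f.endswith(".py")]
    return [f for f in changed if any(f.startswith(d.rstrip("/") + "/") for d in dirs)]


if __name__ == "__main__":
    print(" ".join(modified_python_files(sys.argv[1:])))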

View File

@@ -15,12 +15,15 @@
 """ Conversion script for the LDM checkpoints. """
 import argparse
-import os
 import json
+import os
 import torch
-from diffusers import UNet2DModel, UNet2DConditionModel
+from diffusers import UNet2DConditionModel, UNet2DModel
 from transformers.file_utils import has_file
 do_only_config = False
 do_only_weights = True
 do_only_renaming = False
@@ -37,9 +40,7 @@ if __name__ == "__main__":
 help="The config json file corresponding to the architecture.",
 )
-parser.add_argument(
-"--dump_path", default=None, type=str, required=True, help="Path to the output model."
-)
+parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 args = parser.parse_args()

View File

@@ -1,9 +1,10 @@
 import argparse
-import OmegaConf
 import torch
-from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
+import OmegaConf
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
 def convert_ldm_original(checkpoint_path, config_path, output_path):
 config = OmegaConf.load(config_path)
@@ -53,4 +54,3 @@ if __name__ == "__main__":
 args = parser.parse_args()
 convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)

View File

@@ -1,31 +1,33 @@
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json
 import torch
+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
 def shave_segments(path, n_shave_prefix_segments=1):
 """
 Removes segments. Positive values shave the first segments, negative shave the last segments.
 """
 if n_shave_prefix_segments >= 0:
-return '.'.join(path.split('.')[n_shave_prefix_segments:])
+return ".".join(path.split(".")[n_shave_prefix_segments:])
 else:
-return '.'.join(path.split('.')[:n_shave_prefix_segments])
+return ".".join(path.split(".")[:n_shave_prefix_segments])
 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
 mapping = []
 for old_item in old_list:
 new_item = old_item
-new_item = new_item.replace('block.', 'resnets.')
-new_item = new_item.replace('conv_shorcut', 'conv1')
-new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
-new_item = new_item.replace('temb_proj', 'time_emb_proj')
+new_item = new_item.replace("block.", "resnets.")
+new_item = new_item.replace("conv_shorcut", "conv1")
+new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+new_item = new_item.replace("temb_proj", "time_emb_proj")
 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-mapping.append({'old': old_item, 'new': new_item})
+mapping.append({"old": old_item, "new": new_item})
 return mapping
@@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
 # In `model.mid`, the layer is called `attn`.
 if not in_mid:
-new_item = new_item.replace('attn', 'attentions')
-new_item = new_item.replace('.k.', '.key.')
-new_item = new_item.replace('.v.', '.value.')
-new_item = new_item.replace('.q.', '.query.')
-new_item = new_item.replace('proj_out', 'proj_attn')
-new_item = new_item.replace('norm', 'group_norm')
+new_item = new_item.replace("attn", "attentions")
+new_item = new_item.replace(".k.", ".key.")
+new_item = new_item.replace(".v.", ".value.")
+new_item = new_item.replace(".q.", ".query.")
+new_item = new_item.replace("proj_out", "proj_attn")
+new_item = new_item.replace("norm", "group_norm")
 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-mapping.append({'old': old_item, 'new': new_item})
+mapping.append({"old": old_item, "new": new_item})
 return mapping
-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
+def assign_to_checkpoint(
+paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
 assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
 if attention_paths_to_split is not None:
@@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
 query, key, value = old_tensor.split(channels // num_heads, dim=1)
-checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
-checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
-checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()
+checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
 for path in paths:
-new_path = path['new']
+new_path = path["new"]
 if attention_paths_to_split is not None and new_path in attention_paths_to_split:
 continue
-new_path = new_path.replace('down.', 'down_blocks.')
-new_path = new_path.replace('up.', 'up_blocks.')
+new_path = new_path.replace("down.", "down_blocks.")
+new_path = new_path.replace("up.", "up_blocks.")
 if additional_replacements is not None:
 for replacement in additional_replacements:
-new_path = new_path.replace(replacement['old'], replacement['new'])
-if 'attentions' in new_path:
-checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
+new_path = new_path.replace(replacement["old"], replacement["new"])
+if "attentions" in new_path:
+checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
 else:
-checkpoint[new_path] = old_checkpoint[path['old']]
+checkpoint[new_path] = old_checkpoint[path["old"]]
 def convert_ddpm_checkpoint(checkpoint, config):
@@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
 """
 new_checkpoint = {}
-new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
-new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
-new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
-new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']
-new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
-new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']
-new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
-new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
-new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
-new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
+new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
+new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
+new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
+new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
+new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
+new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
+new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
+new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
+new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
+new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
-num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
+down_blocks = {
+layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+}
-num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
-up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
+up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
 for i in range(num_down_blocks):
-block_id = (i - 1) // (config['layers_per_block'] + 1)
+block_id = (i - 1) // (config["layers_per_block"] + 1)
-if any('downsample' in layer for layer in down_blocks[i]):
-new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
-new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+if any("downsample" in layer for layer in down_blocks[i]):
+new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+f"down.{i}.downsample.op.weight"
+]
+new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
 # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
 # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
-if any('block' in layer for layer in down_blocks[i]):
-num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
-blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("block" in layer for layer in down_blocks[i]):
+num_blocks = len(
+{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
+)
+blocks = {
+layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+for layer_id in range(num_blocks)
+}
 if num_blocks > 0:
-for j in range(config['layers_per_block']):
+for j in range(config["layers_per_block"]):
 paths = renew_resnet_paths(blocks[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint)
-if any('attn' in layer for layer in down_blocks[i]):
-num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
-attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("attn" in layer for layer in down_blocks[i]):
+num_attn = len(
+{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
+)
+attns = {
+layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+for layer_id in range(num_blocks)
+}
 if num_attn > 0:
-for j in range(config['layers_per_block']):
+for j in range(config["layers_per_block"]):
 paths = renew_attention_paths(attns[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
 # Mid new 2
 paths = renew_resnet_paths(mid_block_1_layers)
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
-])
+assign_to_checkpoint(
+paths,
+new_checkpoint,
+checkpoint,
+additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+)
 paths = renew_resnet_paths(mid_block_2_layers)
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
-])
+assign_to_checkpoint(
+paths,
+new_checkpoint,
+checkpoint,
+additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+)
 paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
-])
+assign_to_checkpoint(
+paths,
+new_checkpoint,
+checkpoint,
+additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+)
 for i in range(num_up_blocks):
 block_id = num_up_blocks - 1 - i
-if any('upsample' in layer for layer in up_blocks[i]):
-new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
-new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
+if any("upsample" in layer for layer in up_blocks[i]):
+new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+f"up.{i}.upsample.conv.weight"
+]
+new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
-if any('block' in layer for layer in up_blocks[i]):
-num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
-blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("block" in layer for layer in up_blocks[i]):
+num_blocks = len(
+{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
+)
+blocks = {
+layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+}
 if num_blocks > 0:
-for j in range(config['layers_per_block'] + 1):
-replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+for j in range(config["layers_per_block"] + 1):
+replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
 paths = renew_resnet_paths(blocks[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
-if any('attn' in layer for layer in up_blocks[i]):
-num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
-attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("attn" in layer for layer in up_blocks[i]):
+num_attn = len(
+{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
+)
+attns = {
+layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+}
 if num_attn > 0:
-for j in range(config['layers_per_block'] + 1):
-replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+for j in range(config["layers_per_block"] + 1):
+replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
 paths = renew_attention_paths(attns[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
-new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
 return new_checkpoint
@@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
 """
 new_checkpoint = {}
-new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
-new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
-new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
-new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
-new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
-new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
-new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
-new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
-new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
-new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
-new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
-new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
+new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
+new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
+new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
+new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
+new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
+new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
+new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
+new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
+new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
+new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
+new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
+new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
-num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
-down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
+down_blocks = {
+layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+}
-num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
-up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
+up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
 for i in range(num_down_blocks):
-block_id = (i - 1) // (config['layers_per_block'] + 1)
+block_id = (i - 1) // (config["layers_per_block"] + 1)
-if any('downsample' in layer for layer in down_blocks[i]):
-new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
-new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
+if any("downsample" in layer for layer in down_blocks[i]):
+new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+f"encoder.down.{i}.downsample.conv.weight"
+]
+new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
+f"encoder.down.{i}.downsample.conv.bias"
+]
-if any('block' in layer for layer in down_blocks[i]):
-num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
-blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("block" in layer for layer in down_blocks[i]):
+num_blocks = len(
+{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
+)
+blocks = {
+layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+for layer_id in range(num_blocks)
+}
 if num_blocks > 0:
-for j in range(config['layers_per_block']):
+for j in range(config["layers_per_block"]):
 paths = renew_resnet_paths(blocks[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint)
-if any('attn' in layer for layer in down_blocks[i]):
-num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
-attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("attn" in layer for layer in down_blocks[i]):
+num_attn = len(
+{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
+)
+attns = {
+layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+for layer_id in range(num_blocks)
+}
 if num_attn > 0:
-for j in range(config['layers_per_block']):
+for j in range(config["layers_per_block"]):
 paths = renew_attention_paths(attns[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
 # Mid new 2
 paths = renew_resnet_paths(mid_block_1_layers)
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
-])
+assign_to_checkpoint(
+paths,
+new_checkpoint,
+checkpoint,
+additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+)
 paths = renew_resnet_paths(mid_block_2_layers)
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
-])
+assign_to_checkpoint(
+paths,
+new_checkpoint,
+checkpoint,
+additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+)
 paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
-])
+assign_to_checkpoint(
+paths,
+new_checkpoint,
+checkpoint,
+additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+)
 for i in range(num_up_blocks):
 block_id = num_up_blocks - 1 - i
-if any('upsample' in layer for layer in up_blocks[i]):
-new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
-new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
+if any("upsample" in layer for layer in up_blocks[i]):
+new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+f"decoder.up.{i}.upsample.conv.weight"
+]
+new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+f"decoder.up.{i}.upsample.conv.bias"
+]
-if any('block' in layer for layer in up_blocks[i]):
-num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
-blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("block" in layer for layer in up_blocks[i]):
+num_blocks = len(
+{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
+)
+blocks = {
+layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+}
 if num_blocks > 0:
-for j in range(config['layers_per_block'] + 1):
-replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+for j in range(config["layers_per_block"] + 1):
+replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
 paths = renew_resnet_paths(blocks[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
-if any('attn' in layer for layer in up_blocks[i]):
-num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
-attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+if any("attn" in layer for layer in up_blocks[i]):
+num_attn = len(
+{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
+)
+attns = {
+layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+}
 if num_attn > 0:
-for j in range(config['layers_per_block'] + 1):
-replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+for j in range(config["layers_per_block"] + 1):
+replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
 paths = renew_attention_paths(attns[j])
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
-new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
 new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
 new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
 if "quantize.embedding.weight" in checkpoint:
@@ -321,9 +395,7 @@ if __name__ == "__main__":
 help="The config json file corresponding to the architecture.",
 )
-parser.add_argument(
-"--dump_path", default=None, type=str, required=True, help="Path to the output model."
-)
+parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 args = parser.parse_args()
 checkpoint = torch.load(args.checkpoint_path)
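The `shave_segments` docstring in the file above is terse; a quick standalone illustration of its behavior, reusing the definition from the diff with made-up key names:

# shave_segments as defined in the script above, plus two example calls.
def shave_segments(path, n_shave_prefix_segments=1):
    # Positive values drop that many leading dot-separated segments,
    # negative values drop that many trailing segments.
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


print(shave_segments("down.0.block.1.conv1.weight", 2))   # -> block.1.conv1.weight
print(shave_segments("down.0.block.1.conv1.weight", -2))  # -> down.0.block.1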

View File

@@ -16,8 +16,10 @@
 import argparse
 import json
 import torch
-from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
+from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel
 def shave_segments(path, n_shave_prefix_segments=1):
@@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
 Removes segments. Positive values shave the first segments, negative shave the last segments.
 """
 if n_shave_prefix_segments >= 0:
-return '.'.join(path.split('.')[n_shave_prefix_segments:])
+return ".".join(path.split(".")[n_shave_prefix_segments:])
 else:
-return '.'.join(path.split('.')[:n_shave_prefix_segments])
+return ".".join(path.split(".")[:n_shave_prefix_segments])
 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
@@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
 """
 mapping = []
 for old_item in old_list:
-new_item = old_item.replace('in_layers.0', 'norm1')
-new_item = new_item.replace('in_layers.2', 'conv1')
-new_item = new_item.replace('out_layers.0', 'norm2')
-new_item = new_item.replace('out_layers.3', 'conv2')
-new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
-new_item = new_item.replace('skip_connection', 'conv_shortcut')
+new_item = old_item.replace("in_layers.0", "norm1")
+new_item = new_item.replace("in_layers.2", "conv1")
+new_item = new_item.replace("out_layers.0", "norm2")
+new_item = new_item.replace("out_layers.3", "conv2")
+new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+new_item = new_item.replace("skip_connection", "conv_shortcut")
 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-mapping.append({'old': old_item, 'new': new_item})
+mapping.append({"old": old_item, "new": new_item})
 return mapping
@@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
 for old_item in old_list:
 new_item = old_item
-new_item = new_item.replace('norm.weight', 'group_norm.weight')
-new_item = new_item.replace('norm.bias', 'group_norm.bias')
-new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
-new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+new_item = new_item.replace("norm.weight", "group_norm.weight")
+new_item = new_item.replace("norm.bias", "group_norm.bias")
+new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-mapping.append({'old': old_item, 'new': new_item})
+mapping.append({"old": old_item, "new": new_item})
 return mapping
-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
+def assign_to_checkpoint(
+paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
 """
 This does the final conversion step: take locally converted weights and apply a global renaming
 to them. It splits attention layers, and takes into account additional replacements
@@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
 query, key, value = old_tensor.split(channels // num_heads, dim=1)
-checkpoint[path_map['query']] = query.reshape(target_shape)
-checkpoint[path_map['key']] = key.reshape(target_shape)
-checkpoint[path_map['value']] = value.reshape(target_shape)
+checkpoint[path_map["query"]] = query.reshape(target_shape)
+checkpoint[path_map["key"]] = key.reshape(target_shape)
+checkpoint[path_map["value"]] = value.reshape(target_shape)
 for path in paths:
-new_path = path['new']
+new_path = path["new"]
 # These have already been assigned
 if attention_paths_to_split is not None and new_path in attention_paths_to_split:
 continue
 # Global renaming happens here
-new_path = new_path.replace('middle_block.0', 'mid.resnets.0')
-new_path = new_path.replace('middle_block.1', 'mid.attentions.0')
-new_path = new_path.replace('middle_block.2', 'mid.resnets.1')
+new_path = new_path.replace("middle_block.0", "mid.resnets.0")
+new_path = new_path.replace("middle_block.1", "mid.attentions.0")
+new_path = new_path.replace("middle_block.2", "mid.resnets.1")
 if additional_replacements is not None:
 for replacement in additional_replacements:
-new_path = new_path.replace(replacement['old'], replacement['new'])
+new_path = new_path.replace(replacement["old"], replacement["new"])
 # proj_attn.weight has to be converted from conv 1D to linear
 if "proj_attn.weight" in new_path:
-checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
+checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
 else:
-checkpoint[new_path] = old_checkpoint[path['old']]
+checkpoint[new_path] = old_checkpoint[path["old"]]
 def convert_ldm_checkpoint(checkpoint, config):
@@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
 """
 new_checkpoint = {}
-new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight']
-new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias']
-new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight']
-new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias']
-new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight']
-new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias']
-new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight']
-new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias']
-new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight']
-new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']
+new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
+new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
+new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
+new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
+new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
+new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
+new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
+new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
+new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
+new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
 # Retrieves the keys for the input blocks only
-num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
-input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
+num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+input_blocks = {
+layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+for layer_id in range(num_input_blocks)
+}
 # Retrieves the keys for the middle blocks only
-num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer})
-middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
+num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
+middle_blocks = {
+layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
+for layer_id in range(num_middle_blocks)
+}
 # Retrieves the keys for the output blocks only
-num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer})
-output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
+num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
+output_blocks = {
+layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
+for layer_id in range(num_output_blocks)
+}
 for i in range(1, num_input_blocks):
-block_id = (i - 1) // (config['num_res_blocks'] + 1)
-layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
+block_id = (i - 1) // (config["num_res_blocks"] + 1)
+layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)
-resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key]
-attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
+resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
+attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-if f'input_blocks.{i}.0.op.weight' in checkpoint:
-new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.weight'] = checkpoint[f'input_blocks.{i}.0.op.weight']
-new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.bias'] = checkpoint[f'input_blocks.{i}.0.op.bias']
+if f"input_blocks.{i}.0.op.weight" in checkpoint:
+new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
+f"input_blocks.{i}.0.op.weight"
+]
+new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
+f"input_blocks.{i}.0.op.bias"
+]
 paths = renew_resnet_paths(resnets)
-meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
-resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
-assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
+meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
+resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
+assign_to_checkpoint(
+paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
+)
 if len(attentions):
 paths = renew_attention_paths(attentions)
-meta_path = {'old': f'input_blocks.{i}.1', 'new': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}'}
+meta_path = {
+"old": f"input_blocks.{i}.1",
+"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
+}
 to_split = {
-f'input_blocks.{i}.1.qkv.bias': {
-'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+f"input_blocks.{i}.1.qkv.bias": {
+"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
+"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
+"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
 },
-f'input_blocks.{i}.1.qkv.weight': {
-'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+f"input_blocks.{i}.1.qkv.weight": {
+"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
+"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
+"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
 },
 }
 assign_to_checkpoint(
@@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
 checkpoint,
 additional_replacements=[meta_path],
 attention_paths_to_split=to_split,
-config=config
+config=config,
 )
 resnet_0 = middle_blocks[0]
@@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
 attentions_paths = renew_attention_paths(attentions)
 to_split = {
-'middle_block.1.qkv.bias': {
-'key': 'mid_block.attentions.0.key.bias',
-'query': 'mid_block.attentions.0.query.bias',
-'value': 'mid_block.attentions.0.value.bias',
+"middle_block.1.qkv.bias": {
+"key": "mid_block.attentions.0.key.bias",
+"query": "mid_block.attentions.0.query.bias",
+"value": "mid_block.attentions.0.value.bias",
 },
-'middle_block.1.qkv.weight': {
-'key': 'mid_block.attentions.0.key.weight',
-'query': 'mid_block.attentions.0.query.weight',
-'value': 'mid_block.attentions.0.value.weight',
+"middle_block.1.qkv.weight": {
+"key": "mid_block.attentions.0.key.weight",
+"query": "mid_block.attentions.0.query.weight",
+"value": "mid_block.attentions.0.value.weight",
 },
 }
-assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
+assign_to_checkpoint(
+attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
+)
 for i in range(num_output_blocks):
-block_id = i // (config['num_res_blocks'] + 1)
-layer_in_block_id = i % (config['num_res_blocks'] + 1)
+block_id = i // (config["num_res_blocks"] + 1)
+layer_in_block_id = i % (config["num_res_blocks"] + 1)
 output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
 output_block_list = {}
 for layer in output_block_layers:
-layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
+layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
 if layer_id in output_block_list:
 output_block_list[layer_id].append(layer_name)
 else:
 output_block_list[layer_id] = [layer_name]
 if len(output_block_list) > 1:
-resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
-attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
+resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
 resnet_0_paths = renew_resnet_paths(resnets)
 paths = renew_resnet_paths(resnets)
-meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
+meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
 assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
-if ['conv.weight', 'conv.bias'] in output_block_list.values():
-index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
-new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
-new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
+if ["conv.weight", "conv.bias"] in output_block_list.values():
+index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+f"output_blocks.{i}.{index}.conv.weight"
+]
+new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+f"output_blocks.{i}.{index}.conv.bias"
+]
 # Clear attentions as they have been attributed above.
 if len(attentions) == 2:
@@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
 if len(attentions):
 paths = renew_attention_paths(attentions)
 meta_path = {
-'old': f'output_blocks.{i}.1',
-'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
+"old": f"output_blocks.{i}.1",
+"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
 }
 to_split = {
-f'output_blocks.{i}.1.qkv.bias': {
-'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+f"output_blocks.{i}.1.qkv.bias": {
+"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
+"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
+"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
 },
-f'output_blocks.{i}.1.qkv.weight': {
-'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+f"output_blocks.{i}.1.qkv.weight": {
+"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
+"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
+"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
 },
 }
 assign_to_checkpoint(
@@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
 new_checkpoint,
 checkpoint,
 additional_replacements=[meta_path],
-attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
+attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
 config=config,
 )
 else:
 resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
 for path in resnet_0_paths:
-old_path = '.'.join(['output_blocks', str(i), path['old']])
-new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
+old_path = ".".join(["output_blocks", str(i), path["old"]])
+new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
 new_checkpoint[new_path] = checkpoint[old_path]
@@ -303,9 +331,7 @@ if __name__ == "__main__":
 help="The config json file corresponding to the architecture.",
 )
-parser.add_argument(
-"--dump_path", default=None, type=str, required=True, help="Path to the output model."
-)
+parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 args = parser.parse_args()
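The `assign_to_checkpoint` docstring above mentions splitting attention layers: a fused qkv tensor is reshaped per head and split into separate query/key/value tensors before being reshaped to the target layout. A toy, self-contained illustration of that mechanism with made-up shapes (not taken from the script):

import torch

# Made-up fused qkv weight; the script reshapes it per attention head and
# splits along dim=1 into query/key/value, then reshapes to the target layout.
num_heads, channels = 1, 4
target_shape = (channels, channels)

qkv = torch.randn(3 * channels, channels, 1)
old_tensor = qkv.reshape((num_heads, 3 * channels // num_heads) + qkv.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

print(query.reshape(target_shape).shape)  # torch.Size([4, 4])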

View File

@@ -16,8 +16,10 @@
 import argparse
 import json
 import torch
-from diffusers import UNet2DModel
+from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
 def convert_ncsnpp_checkpoint(checkpoint, config):

View File

@@ -1,71 +1,105 @@
-from huggingface_hub import HfApi
-from transformers.file_utils import has_file
-from diffusers import UNet2DModel
 import random
 import torch
+from diffusers import UNet2DModel
+from huggingface_hub import HfApi
 api = HfApi()
 results = {}
-results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
+# fmt: off
+results["google_ddpm_cifar10_32"] = torch.tensor([
+-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
 -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
-0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
-results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
+0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
+])
+results["google_ddpm_ema_bedroom_256"] = torch.tensor([
+-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
 1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
 -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
-2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365])
-results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
+2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
+])
+results["CompVis_ldm_celebahq_256"] = torch.tensor([
+-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
 -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
 -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
-0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943])
-results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
+0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
+])
+results["google_ncsnpp_ffhq_1024"] = torch.tensor([
+0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
 -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
 0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
--0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505])
-results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
+-0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
+])
+results["google_ncsnpp_bedroom_256"] = torch.tensor([
+0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
 -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
 0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
--0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386])
-results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
+-0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
+])
+results["google_ncsnpp_celebahq_256"] = torch.tensor([
+0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
 -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
 0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
--0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431])
-results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
+-0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
+])
+results["google_ncsnpp_church_256"] = torch.tensor([
+0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
 -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
 0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
--0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390])
-results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
+-0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
+])
+results["google_ncsnpp_ffhq_256"] = torch.tensor([
+0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
 -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
 0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
--0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473])
-results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
+-0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
+])
+results["google_ddpm_cat_256"] = torch.tensor([
+-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
 1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
 -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
 1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
-results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
+results["google_ddpm_celebahq_256"] = torch.tensor([
+-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
 0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
 -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
-1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266])
-results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
+1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
+])
+results["google_ddpm_ema_celebahq_256"] = torch.tensor([
+-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
 0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
 -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
-1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355])
-results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
+1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
+])
+results["google_ddpm_church_256"] = torch.tensor([
+-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
 1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
 -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
-3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066])
-results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
+3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
+])
+results["google_ddpm_bedroom_256"] = torch.tensor([
+-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
 1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
 -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
-2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243])
-results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
+2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
+])
+results["google_ddpm_ema_church_256"] = torch.tensor([
+-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
 1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
 -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
-3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343])
-results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
+3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
+])
+results["google_ddpm_ema_cat_256"] = torch.tensor([
+-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
 1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
 -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
-1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219])
+1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
+])
+# fmt: on
 models = api.list_models(filter="diffusers")
 for mod in models:
@@ -85,7 +119,9 @@ for mod in models:
 noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
 time_step = torch.tensor([10] * noise.shape[0])
 with torch.no_grad():
-logits = model(noise, time_step)['sample']
+logits = model(noise, time_step)["sample"]
-assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
+assert torch.allclose(
+logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
+)
 print(f"{mod.modelId} has passed succesfully!!!")