Style the `scripts` directory (#250)

Style scripts
This commit is contained in:
Anton Lozhkov 2022-08-25 15:46:09 +02:00 committed by GitHub
parent 365f75233f
commit 89793a97e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 447 additions and 310 deletions

View File

@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
check_dirs := examples tests src utils
check_dirs := examples scripts src tests utils
modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))

View File

@ -15,12 +15,15 @@
""" Conversion script for the LDM checkpoints. """
import argparse
import os
import json
import os
import torch
from diffusers import UNet2DModel, UNet2DConditionModel
from diffusers import UNet2DConditionModel, UNet2DModel
from transformers.file_utils import has_file
do_only_config = False
do_only_weights = True
do_only_renaming = False
@ -37,9 +40,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.",
)
parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()

View File

@ -1,9 +1,10 @@
import argparse
import OmegaConf
import torch
from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
import OmegaConf
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
def convert_ldm_original(checkpoint_path, config_path, output_path):
config = OmegaConf.load(config_path)
@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
for key in keys:
if key.startswith(first_stage_key):
first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
# extract state_dict for UNetLDM
unet_state_dict = {}
unet_key = "model.diffusion_model."
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
vqvae_init_args = config.model.params.first_stage_config.params
unet_init_args = config.model.params.unet_config.params
@ -53,4 +54,3 @@ if __name__ == "__main__":
args = parser.parse_args()
convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)

View File

@ -1,31 +1,33 @@
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
import argparse
import json
import torch
from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return '.'.join(path.split('.')[n_shave_prefix_segments:])
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return '.'.join(path.split('.')[:n_shave_prefix_segments])
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace('block.', 'resnets.')
new_item = new_item.replace('conv_shorcut', 'conv1')
new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
new_item = new_item.replace('temb_proj', 'time_emb_proj')
new_item = new_item.replace("block.", "resnets.")
new_item = new_item.replace("conv_shorcut", "conv1")
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = new_item.replace("temb_proj", "time_emb_proj")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})
return mapping
@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
# In `model.mid`, the layer is called `attn`.
if not in_mid:
new_item = new_item.replace('attn', 'attentions')
new_item = new_item.replace('.k.', '.key.')
new_item = new_item.replace('.v.', '.value.')
new_item = new_item.replace('.q.', '.query.')
new_item = new_item.replace("attn", "attentions")
new_item = new_item.replace(".k.", ".key.")
new_item = new_item.replace(".v.", ".value.")
new_item = new_item.replace(".q.", ".query.")
new_item = new_item.replace('proj_out', 'proj_attn')
new_item = new_item.replace('norm', 'group_norm')
new_item = new_item.replace("proj_out", "proj_attn")
new_item = new_item.replace("norm", "group_norm")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
if attention_paths_to_split is not None:
@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()
checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
for path in paths:
new_path = path['new']
new_path = path["new"]
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
new_path = new_path.replace('down.', 'down_blocks.')
new_path = new_path.replace('up.', 'up_blocks.')
new_path = new_path.replace("down.", "down_blocks.")
new_path = new_path.replace("up.", "up_blocks.")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new'])
new_path = new_path.replace(replacement["old"], replacement["new"])
if 'attentions' in new_path:
checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
if "attentions" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
else:
checkpoint[new_path] = old_checkpoint[path['old']]
checkpoint[new_path] = old_checkpoint[path["old"]]
def convert_ddpm_checkpoint(checkpoint, config):
@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
"""
new_checkpoint = {}
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']
new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
down_blocks = {
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
for i in range(num_down_blocks):
block_id = (i - 1) // (config['layers_per_block'] + 1)
block_id = (i - 1) // (config["layers_per_block"] + 1)
if any('downsample' in layer for layer in down_blocks[i]):
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
if any("downsample" in layer for layer in down_blocks[i]):
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
f"down.{i}.downsample.op.weight"
]
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
if any('block' in layer for layer in down_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("block" in layer for layer in down_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config['layers_per_block']):
for j in range(config["layers_per_block"]):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
if any('attn' in layer for layer in down_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("attn" in layer for layer in down_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config['layers_per_block']):
for j in range(config["layers_per_block"]):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
# Mid new 2
paths = renew_resnet_paths(mid_block_1_layers)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
])
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
)
paths = renew_resnet_paths(mid_block_2_layers)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
])
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
)
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
if any('upsample' in layer for layer in up_blocks[i]):
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
if any("upsample" in layer for layer in up_blocks[i]):
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
f"up.{i}.upsample.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
if any('block' in layer for layer in up_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("block" in layer for layer in up_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any('attn' in layer for layer in up_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("attn" in layer for layer in up_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
return new_checkpoint
@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
"""
new_checkpoint = {}
new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
down_blocks = {
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
for i in range(num_down_blocks):
block_id = (i - 1) // (config['layers_per_block'] + 1)
block_id = (i - 1) // (config["layers_per_block"] + 1)
if any('downsample' in layer for layer in down_blocks[i]):
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
if any("downsample" in layer for layer in down_blocks[i]):
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
f"encoder.down.{i}.downsample.conv.weight"
]
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
f"encoder.down.{i}.downsample.conv.bias"
]
if any('block' in layer for layer in down_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("block" in layer for layer in down_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config['layers_per_block']):
for j in range(config["layers_per_block"]):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
if any('attn' in layer for layer in down_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("attn" in layer for layer in down_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config['layers_per_block']):
for j in range(config["layers_per_block"]):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
# Mid new 2
paths = renew_resnet_paths(mid_block_1_layers)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
])
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
)
paths = renew_resnet_paths(mid_block_2_layers)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
])
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
)
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
if any('upsample' in layer for layer in up_blocks[i]):
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
if any("upsample" in layer for layer in up_blocks[i]):
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
f"decoder.up.{i}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
f"decoder.up.{i}.upsample.conv.bias"
]
if any('block' in layer for layer in up_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("block" in layer for layer in up_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any('attn' in layer for layer in up_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any("attn" in layer for layer in up_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
if "quantize.embedding.weight" in checkpoint:
@ -321,9 +395,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.",
)
parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path)

View File

@ -16,8 +16,10 @@
import argparse
import json
import torch
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel
def shave_segments(path, n_shave_prefix_segments=1):
@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return '.'.join(path.split('.')[n_shave_prefix_segments:])
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return '.'.join(path.split('.')[:n_shave_prefix_segments])
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace('in_layers.0', 'norm1')
new_item = new_item.replace('in_layers.2', 'conv1')
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace('out_layers.0', 'norm2')
new_item = new_item.replace('out_layers.3', 'conv2')
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
new_item = new_item.replace('skip_connection', 'conv_shortcut')
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})
return mapping
@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
for old_item in old_list:
new_item = old_item
new_item = new_item.replace('norm.weight', 'group_norm.weight')
new_item = new_item.replace('norm.bias', 'group_norm.bias')
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map['query']] = query.reshape(target_shape)
checkpoint[path_map['key']] = key.reshape(target_shape)
checkpoint[path_map['value']] = value.reshape(target_shape)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path['new']
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace('middle_block.0', 'mid.resnets.0')
new_path = new_path.replace('middle_block.1', 'mid.attentions.0')
new_path = new_path.replace('middle_block.2', 'mid.resnets.1')
new_path = new_path.replace("middle_block.0", "mid.resnets.0")
new_path = new_path.replace("middle_block.1", "mid.attentions.0")
new_path = new_path.replace("middle_block.2", "mid.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new'])
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path['old']]
checkpoint[new_path] = old_checkpoint[path["old"]]
def convert_ldm_checkpoint(checkpoint, config):
@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
"""
new_checkpoint = {}
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight']
new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias']
new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight']
new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias']
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight']
new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias']
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight']
new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias']
new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight']
new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer})
middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer})
output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
block_id = (i - 1) // (config["num_res_blocks"] + 1)
layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)
resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key]
attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f'input_blocks.{i}.0.op.weight' in checkpoint:
new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.weight'] = checkpoint[f'input_blocks.{i}.0.op.weight']
new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.bias'] = checkpoint[f'input_blocks.{i}.0.op.bias']
if f"input_blocks.{i}.0.op.weight" in checkpoint:
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
f"input_blocks.{i}.0.op.weight"
]
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
f"input_blocks.{i}.0.op.bias"
]
paths = renew_resnet_paths(resnets)
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
assign_to_checkpoint(
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {'old': f'input_blocks.{i}.1', 'new': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}'}
meta_path = {
"old": f"input_blocks.{i}.1",
"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
}
to_split = {
f'input_blocks.{i}.1.qkv.bias': {
'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
f"input_blocks.{i}.1.qkv.bias": {
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
},
f'input_blocks.{i}.1.qkv.weight': {
'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
f"input_blocks.{i}.1.qkv.weight": {
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
},
}
assign_to_checkpoint(
@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
checkpoint,
additional_replacements=[meta_path],
attention_paths_to_split=to_split,
config=config
config=config,
)
resnet_0 = middle_blocks[0]
@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions_paths = renew_attention_paths(attentions)
to_split = {
'middle_block.1.qkv.bias': {
'key': 'mid_block.attentions.0.key.bias',
'query': 'mid_block.attentions.0.query.bias',
'value': 'mid_block.attentions.0.value.bias',
"middle_block.1.qkv.bias": {
"key": "mid_block.attentions.0.key.bias",
"query": "mid_block.attentions.0.query.bias",
"value": "mid_block.attentions.0.value.bias",
},
'middle_block.1.qkv.weight': {
'key': 'mid_block.attentions.0.key.weight',
'query': 'mid_block.attentions.0.query.weight',
'value': 'mid_block.attentions.0.value.weight',
"middle_block.1.qkv.weight": {
"key": "mid_block.attentions.0.key.weight",
"query": "mid_block.attentions.0.query.weight",
"value": "mid_block.attentions.0.value.weight",
},
}
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
assign_to_checkpoint(
attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
)
for i in range(num_output_blocks):
block_id = i // (config['num_res_blocks'] + 1)
layer_in_block_id = i % (config['num_res_blocks'] + 1)
block_id = i // (config["num_res_blocks"] + 1)
layer_in_block_id = i % (config["num_res_blocks"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
if ['conv.weight', 'conv.bias'] in output_block_list.values():
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
'old': f'output_blocks.{i}.1',
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
to_split = {
f'output_blocks.{i}.1.qkv.bias': {
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
f"output_blocks.{i}.1.qkv.bias": {
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
},
f'output_blocks.{i}.1.qkv.weight': {
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
f"output_blocks.{i}.1.qkv.weight": {
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
},
}
assign_to_checkpoint(
@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint,
checkpoint,
additional_replacements=[meta_path],
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
config=config,
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = '.'.join(['output_blocks', str(i), path['old']])
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = checkpoint[old_path]
@ -303,9 +331,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.",
)
parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()

View File

@ -16,8 +16,10 @@
import argparse
import json
import torch
from diffusers import UNet2DModel
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
def convert_ncsnpp_checkpoint(checkpoint, config):

View File

@ -1,91 +1,127 @@
from huggingface_hub import HfApi
from transformers.file_utils import has_file
from diffusers import UNet2DModel
import random
import torch
from diffusers import UNet2DModel
from huggingface_hub import HfApi
api = HfApi()
results = {}
results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
-1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
-2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365])
results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
-0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
-0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943])
results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
-0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
-0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505])
results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
-0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
-0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386])
results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
-0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
-0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431])
results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
-0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
-0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390])
results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
-0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
-0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473])
results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
-2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
-2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266])
results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
-2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355])
results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
-3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066])
results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
-2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243])
results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
-3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343])
results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
-2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219])
# fmt: off
results["google_ddpm_cifar10_32"] = torch.tensor([
-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
-1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
])
results["google_ddpm_ema_bedroom_256"] = torch.tensor([
-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
-2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
])
results["CompVis_ldm_celebahq_256"] = torch.tensor([
-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
-0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
-0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
])
results["google_ncsnpp_ffhq_1024"] = torch.tensor([
0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
-0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
-0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
])
results["google_ncsnpp_bedroom_256"] = torch.tensor([
0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
-0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
-0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
])
results["google_ncsnpp_celebahq_256"] = torch.tensor([
0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
-0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
-0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
])
results["google_ncsnpp_church_256"] = torch.tensor([
0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
-0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
-0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
])
results["google_ncsnpp_ffhq_256"] = torch.tensor([
0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
-0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
-0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
])
results["google_ddpm_cat_256"] = torch.tensor([
-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
-2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
results["google_ddpm_celebahq_256"] = torch.tensor([
-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
-2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
])
results["google_ddpm_ema_celebahq_256"] = torch.tensor([
-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
-2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
])
results["google_ddpm_church_256"] = torch.tensor([
-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
-3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
])
results["google_ddpm_bedroom_256"] = torch.tensor([
-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
-2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
])
results["google_ddpm_ema_church_256"] = torch.tensor([
-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
-3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
])
results["google_ddpm_ema_cat_256"] = torch.tensor([
-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
-2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
])
# fmt: on
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
print(f"Started running {mod.modelId}!!!")
if mod.modelId.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
else:
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
else:
model = UNet2DModel.from_pretrained(local_checkpoint)
torch.manual_seed(0)
random.seed(0)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)['sample']
logits = model(noise, time_step)["sample"]
assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
)
print(f"{mod.modelId} has passed succesfully!!!")