parent 365f75233f
commit 89793a97e2

Makefile | 2
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
 
-check_dirs := examples tests src utils
+check_dirs := examples scripts src tests utils
 
 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@@ -15,12 +15,15 @@
 """ Conversion script for the LDM checkpoints. """
 
 import argparse
-import os
 import json
+import os
 
 import torch
-from diffusers import UNet2DModel, UNet2DConditionModel
+
+from diffusers import UNet2DConditionModel, UNet2DModel
 from transformers.file_utils import has_file
 
 
 do_only_config = False
 do_only_weights = True
 do_only_renaming = False
@@ -37,9 +40,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
 
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
 
@@ -1,9 +1,10 @@
 import argparse
 
-import OmegaConf
 import torch
 
-from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
+import OmegaConf
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
 
 
 def convert_ldm_original(checkpoint_path, config_path, output_path):
     config = OmegaConf.load(config_path)
@@ -53,4 +54,3 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
-
@@ -1,31 +1,33 @@
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json
 
 import torch
 
+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
+
 
 def shave_segments(path, n_shave_prefix_segments=1):
     """
     Removes segments. Positive values shave the first segments, negative shave the last segments.
     """
     if n_shave_prefix_segments >= 0:
-        return '.'.join(path.split('.')[n_shave_prefix_segments:])
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
     else:
-        return '.'.join(path.split('.')[:n_shave_prefix_segments])
+        return ".".join(path.split(".")[:n_shave_prefix_segments])
 
 
 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
     mapping = []
     for old_item in old_list:
         new_item = old_item
-        new_item = new_item.replace('block.', 'resnets.')
-        new_item = new_item.replace('conv_shorcut', 'conv1')
-        new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
-        new_item = new_item.replace('temb_proj', 'time_emb_proj')
+        new_item = new_item.replace("block.", "resnets.")
+        new_item = new_item.replace("conv_shorcut", "conv1")
+        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+        new_item = new_item.replace("temb_proj", "time_emb_proj")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})
 
     return mapping
 
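
For reference, a minimal sketch (not part of the commit) of what the renaming helpers in the hunk above produce; the checkpoint key used here is a made-up example:

```python
# Illustrative sketch only -- mirrors shave_segments / the replace chain shown above.
def shave_segments(path, n_shave_prefix_segments=1):
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

old_key = "down.0.block.1.temb_proj.weight"  # hypothetical original DDPM key
new_key = old_key.replace("block.", "resnets.").replace("temb_proj", "time_emb_proj")
print(shave_segments(new_key, 0))  # down.0.resnets.1.time_emb_proj.weight
print(shave_segments(new_key, 2))  # resnets.1.time_emb_proj.weight
```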
@@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
 
         # In `model.mid`, the layer is called `attn`.
         if not in_mid:
-            new_item = new_item.replace('attn', 'attentions')
-        new_item = new_item.replace('.k.', '.key.')
-        new_item = new_item.replace('.v.', '.value.')
-        new_item = new_item.replace('.q.', '.query.')
+            new_item = new_item.replace("attn", "attentions")
+        new_item = new_item.replace(".k.", ".key.")
+        new_item = new_item.replace(".v.", ".value.")
+        new_item = new_item.replace(".q.", ".query.")
 
-        new_item = new_item.replace('proj_out', 'proj_attn')
-        new_item = new_item.replace('norm', 'group_norm')
+        new_item = new_item.replace("proj_out", "proj_attn")
+        new_item = new_item.replace("norm", "group_norm")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})
 
     return mapping
 
 
-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
     assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
 
     if attention_paths_to_split is not None:
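
As a rough usage sketch (not part of the commit), the reformatted `assign_to_checkpoint` signature above is driven with mappings from the renaming helpers; `new_checkpoint` and `checkpoint` stand in for the dicts built elsewhere in this script:

```python
# Sketch of a typical call, mirroring the mid-block handling shown later in this diff.
paths = renew_resnet_paths(mid_block_1_layers)  # [{"old": ..., "new": ...}, ...]
assign_to_checkpoint(
    paths,
    new_checkpoint,  # dict being filled with diffusers-style keys
    checkpoint,      # original state dict loaded with torch.load
    additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
)
```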
@@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
 
-            checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
-            checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
-            checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()
+            checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+            checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+            checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
 
     for path in paths:
-        new_path = path['new']
+        new_path = path["new"]
 
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue
 
-        new_path = new_path.replace('down.', 'down_blocks.')
-        new_path = new_path.replace('up.', 'up_blocks.')
+        new_path = new_path.replace("down.", "down_blocks.")
+        new_path = new_path.replace("up.", "up_blocks.")
 
         if additional_replacements is not None:
             for replacement in additional_replacements:
-                new_path = new_path.replace(replacement['old'], replacement['new'])
+                new_path = new_path.replace(replacement["old"], replacement["new"])
 
-        if 'attentions' in new_path:
-            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
+        if "attentions" in new_path:
+            checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
         else:
-            checkpoint[new_path] = old_checkpoint[path['old']]
+            checkpoint[new_path] = old_checkpoint[path["old"]]
 
 
 def convert_ddpm_checkpoint(checkpoint, config):
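
The reshape/split in the hunk above is the step that breaks a fused qkv tensor into separate query/key/value weights. A minimal runnable sketch (not part of the commit, with hypothetical sizes):

```python
import torch

# Mirrors the reshape/split logic in assign_to_checkpoint above.
channels, num_heads = 64, 4
old_tensor = torch.randn(3 * channels, channels)  # hypothetical fused qkv weight
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
target_shape = (channels, channels)
print(query.reshape(target_shape).shape)  # torch.Size([64, 64])
```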
@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
|
||||||
"""
|
"""
|
||||||
new_checkpoint = {}
|
new_checkpoint = {}
|
||||||
|
|
||||||
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
|
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
|
||||||
new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
|
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
|
||||||
new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
|
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
|
||||||
new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']
|
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
|
||||||
|
|
||||||
new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
|
new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
|
||||||
new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']
|
new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
|
||||||
|
|
||||||
new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
|
new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
|
||||||
new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
|
new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
|
||||||
new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
|
new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
|
||||||
new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
|
new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
|
||||||
|
|
||||||
num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
|
num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
|
||||||
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
|
down_blocks = {
|
||||||
|
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
|
num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
|
||||||
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
||||||
|
|
||||||
for i in range(num_down_blocks):
|
for i in range(num_down_blocks):
|
||||||
block_id = (i - 1) // (config['layers_per_block'] + 1)
|
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||||
|
|
||||||
if any('downsample' in layer for layer in down_blocks[i]):
|
if any("downsample" in layer for layer in down_blocks[i]):
|
||||||
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
|
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
|
||||||
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
|
f"down.{i}.downsample.op.weight"
|
||||||
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
|
]
|
||||||
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
|
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
|
||||||
|
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
|
||||||
|
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
|
||||||
|
|
||||||
if any('block' in layer for layer in down_blocks[i]):
|
if any("block" in layer for layer in down_blocks[i]):
|
||||||
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
|
num_blocks = len(
|
||||||
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
|
||||||
|
)
|
||||||
|
blocks = {
|
||||||
|
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_blocks > 0:
|
if num_blocks > 0:
|
||||||
for j in range(config['layers_per_block']):
|
for j in range(config["layers_per_block"]):
|
||||||
paths = renew_resnet_paths(blocks[j])
|
paths = renew_resnet_paths(blocks[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
||||||
|
|
||||||
if any('attn' in layer for layer in down_blocks[i]):
|
if any("attn" in layer for layer in down_blocks[i]):
|
||||||
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
|
num_attn = len(
|
||||||
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
|
||||||
|
)
|
||||||
|
attns = {
|
||||||
|
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_attn > 0:
|
if num_attn > 0:
|
||||||
for j in range(config['layers_per_block']):
|
for j in range(config["layers_per_block"]):
|
||||||
paths = renew_attention_paths(attns[j])
|
paths = renew_attention_paths(attns[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
||||||
|
|
||||||
|
@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
|
||||||
|
|
||||||
# Mid new 2
|
# Mid new 2
|
||||||
paths = renew_resnet_paths(mid_block_1_layers)
|
paths = renew_resnet_paths(mid_block_1_layers)
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
assign_to_checkpoint(
|
||||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
|
paths,
|
||||||
])
|
new_checkpoint,
|
||||||
|
checkpoint,
|
||||||
|
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
|
||||||
|
)
|
||||||
|
|
||||||
paths = renew_resnet_paths(mid_block_2_layers)
|
paths = renew_resnet_paths(mid_block_2_layers)
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
assign_to_checkpoint(
|
||||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
|
paths,
|
||||||
])
|
new_checkpoint,
|
||||||
|
checkpoint,
|
||||||
|
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
|
||||||
|
)
|
||||||
|
|
||||||
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
assign_to_checkpoint(
|
||||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
|
paths,
|
||||||
])
|
new_checkpoint,
|
||||||
|
checkpoint,
|
||||||
|
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(num_up_blocks):
|
for i in range(num_up_blocks):
|
||||||
block_id = num_up_blocks - 1 - i
|
block_id = num_up_blocks - 1 - i
|
||||||
|
|
||||||
if any('upsample' in layer for layer in up_blocks[i]):
|
if any("upsample" in layer for layer in up_blocks[i]):
|
||||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
||||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
|
f"up.{i}.upsample.conv.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
|
||||||
|
|
||||||
if any('block' in layer for layer in up_blocks[i]):
|
if any("block" in layer for layer in up_blocks[i]):
|
||||||
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
|
num_blocks = len(
|
||||||
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
|
||||||
|
)
|
||||||
|
blocks = {
|
||||||
|
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_blocks > 0:
|
if num_blocks > 0:
|
||||||
for j in range(config['layers_per_block'] + 1):
|
for j in range(config["layers_per_block"] + 1):
|
||||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||||
paths = renew_resnet_paths(blocks[j])
|
paths = renew_resnet_paths(blocks[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||||
|
|
||||||
if any('attn' in layer for layer in up_blocks[i]):
|
if any("attn" in layer for layer in up_blocks[i]):
|
||||||
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
|
num_attn = len(
|
||||||
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
|
||||||
|
)
|
||||||
|
attns = {
|
||||||
|
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_attn > 0:
|
if num_attn > 0:
|
||||||
for j in range(config['layers_per_block'] + 1):
|
for j in range(config["layers_per_block"] + 1):
|
||||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||||
paths = renew_attention_paths(attns[j])
|
paths = renew_attention_paths(attns[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||||
|
|
||||||
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
|
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
|
||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
|
||||||
"""
|
"""
|
||||||
new_checkpoint = {}
|
new_checkpoint = {}
|
||||||
|
|
||||||
new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
|
new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
|
||||||
new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
|
new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
|
||||||
|
|
||||||
new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
|
new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
|
||||||
new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
|
new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
|
||||||
new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
|
new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
|
||||||
new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
|
new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
|
||||||
|
|
||||||
new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
|
new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
|
||||||
new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
|
new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
|
||||||
|
|
||||||
new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
|
new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
|
||||||
new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
|
new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
|
||||||
new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
|
new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
|
||||||
new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
|
new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
|
||||||
|
|
||||||
num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
|
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
|
||||||
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
|
down_blocks = {
|
||||||
|
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
|
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
|
||||||
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
||||||
|
|
||||||
for i in range(num_down_blocks):
|
for i in range(num_down_blocks):
|
||||||
block_id = (i - 1) // (config['layers_per_block'] + 1)
|
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||||
|
|
||||||
if any('downsample' in layer for layer in down_blocks[i]):
|
if any("downsample" in layer for layer in down_blocks[i]):
|
||||||
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
|
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
|
||||||
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
|
f"encoder.down.{i}.downsample.conv.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
|
||||||
|
f"encoder.down.{i}.downsample.conv.bias"
|
||||||
|
]
|
||||||
|
|
||||||
if any('block' in layer for layer in down_blocks[i]):
|
if any("block" in layer for layer in down_blocks[i]):
|
||||||
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
|
num_blocks = len(
|
||||||
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
|
||||||
|
)
|
||||||
|
blocks = {
|
||||||
|
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_blocks > 0:
|
if num_blocks > 0:
|
||||||
for j in range(config['layers_per_block']):
|
for j in range(config["layers_per_block"]):
|
||||||
paths = renew_resnet_paths(blocks[j])
|
paths = renew_resnet_paths(blocks[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
||||||
|
|
||||||
if any('attn' in layer for layer in down_blocks[i]):
|
if any("attn" in layer for layer in down_blocks[i]):
|
||||||
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
|
num_attn = len(
|
||||||
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
|
||||||
|
)
|
||||||
|
attns = {
|
||||||
|
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_attn > 0:
|
if num_attn > 0:
|
||||||
for j in range(config['layers_per_block']):
|
for j in range(config["layers_per_block"]):
|
||||||
paths = renew_attention_paths(attns[j])
|
paths = renew_attention_paths(attns[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
||||||
|
|
||||||
|
@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
|
||||||
|
|
||||||
# Mid new 2
|
# Mid new 2
|
||||||
paths = renew_resnet_paths(mid_block_1_layers)
|
paths = renew_resnet_paths(mid_block_1_layers)
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
assign_to_checkpoint(
|
||||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
|
paths,
|
||||||
])
|
new_checkpoint,
|
||||||
|
checkpoint,
|
||||||
|
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
|
||||||
|
)
|
||||||
|
|
||||||
paths = renew_resnet_paths(mid_block_2_layers)
|
paths = renew_resnet_paths(mid_block_2_layers)
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
assign_to_checkpoint(
|
||||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
|
paths,
|
||||||
])
|
new_checkpoint,
|
||||||
|
checkpoint,
|
||||||
|
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
|
||||||
|
)
|
||||||
|
|
||||||
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
assign_to_checkpoint(
|
||||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
|
paths,
|
||||||
])
|
new_checkpoint,
|
||||||
|
checkpoint,
|
||||||
|
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(num_up_blocks):
|
for i in range(num_up_blocks):
|
||||||
block_id = num_up_blocks - 1 - i
|
block_id = num_up_blocks - 1 - i
|
||||||
|
|
||||||
if any('upsample' in layer for layer in up_blocks[i]):
|
if any("upsample" in layer for layer in up_blocks[i]):
|
||||||
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
|
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
||||||
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
|
f"decoder.up.{i}.upsample.conv.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
|
||||||
|
f"decoder.up.{i}.upsample.conv.bias"
|
||||||
|
]
|
||||||
|
|
||||||
if any('block' in layer for layer in up_blocks[i]):
|
if any("block" in layer for layer in up_blocks[i]):
|
||||||
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
|
num_blocks = len(
|
||||||
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
|
||||||
|
)
|
||||||
|
blocks = {
|
||||||
|
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_blocks > 0:
|
if num_blocks > 0:
|
||||||
for j in range(config['layers_per_block'] + 1):
|
for j in range(config["layers_per_block"] + 1):
|
||||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||||
paths = renew_resnet_paths(blocks[j])
|
paths = renew_resnet_paths(blocks[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||||
|
|
||||||
if any('attn' in layer for layer in up_blocks[i]):
|
if any("attn" in layer for layer in up_blocks[i]):
|
||||||
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
|
num_attn = len(
|
||||||
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
|
||||||
|
)
|
||||||
|
attns = {
|
||||||
|
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
if num_attn > 0:
|
if num_attn > 0:
|
||||||
for j in range(config['layers_per_block'] + 1):
|
for j in range(config["layers_per_block"] + 1):
|
||||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||||
paths = renew_attention_paths(attns[j])
|
paths = renew_attention_paths(attns[j])
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||||
|
|
||||||
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
|
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
|
||||||
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
|
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
|
||||||
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
|
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
|
||||||
if "quantize.embedding.weight" in checkpoint:
|
if "quantize.embedding.weight" in checkpoint:
|
||||||
|
@@ -321,9 +395,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
 
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
     checkpoint = torch.load(args.checkpoint_path)
@@ -16,8 +16,10 @@
 import argparse
 import json
 
 import torch
-from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
+
+from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel
 
 
 def shave_segments(path, n_shave_prefix_segments=1):
@@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
     Removes segments. Positive values shave the first segments, negative shave the last segments.
     """
     if n_shave_prefix_segments >= 0:
-        return '.'.join(path.split('.')[n_shave_prefix_segments:])
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
     else:
-        return '.'.join(path.split('.')[:n_shave_prefix_segments])
+        return ".".join(path.split(".")[:n_shave_prefix_segments])
 
 
 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
@@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
     """
     mapping = []
     for old_item in old_list:
-        new_item = old_item.replace('in_layers.0', 'norm1')
-        new_item = new_item.replace('in_layers.2', 'conv1')
+        new_item = old_item.replace("in_layers.0", "norm1")
+        new_item = new_item.replace("in_layers.2", "conv1")
 
-        new_item = new_item.replace('out_layers.0', 'norm2')
-        new_item = new_item.replace('out_layers.3', 'conv2')
+        new_item = new_item.replace("out_layers.0", "norm2")
+        new_item = new_item.replace("out_layers.3", "conv2")
 
-        new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
-        new_item = new_item.replace('skip_connection', 'conv_shortcut')
+        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+        new_item = new_item.replace("skip_connection", "conv_shortcut")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})
 
     return mapping
 
@@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     for old_item in old_list:
         new_item = old_item
 
-        new_item = new_item.replace('norm.weight', 'group_norm.weight')
-        new_item = new_item.replace('norm.bias', 'group_norm.bias')
+        new_item = new_item.replace("norm.weight", "group_norm.weight")
+        new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
-        new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})
 
     return mapping
 
 
-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
     to them. It splits attention layers, and takes into account additional replacements
@@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
 
-            checkpoint[path_map['query']] = query.reshape(target_shape)
-            checkpoint[path_map['key']] = key.reshape(target_shape)
-            checkpoint[path_map['value']] = value.reshape(target_shape)
+            checkpoint[path_map["query"]] = query.reshape(target_shape)
+            checkpoint[path_map["key"]] = key.reshape(target_shape)
+            checkpoint[path_map["value"]] = value.reshape(target_shape)
 
     for path in paths:
-        new_path = path['new']
+        new_path = path["new"]
 
         # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue
 
         # Global renaming happens here
-        new_path = new_path.replace('middle_block.0', 'mid.resnets.0')
-        new_path = new_path.replace('middle_block.1', 'mid.attentions.0')
-        new_path = new_path.replace('middle_block.2', 'mid.resnets.1')
+        new_path = new_path.replace("middle_block.0", "mid.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid.resnets.1")
 
         if additional_replacements is not None:
             for replacement in additional_replacements:
-                new_path = new_path.replace(replacement['old'], replacement['new'])
+                new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
         if "proj_attn.weight" in new_path:
-            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
         else:
-            checkpoint[new_path] = old_checkpoint[path['old']]
+            checkpoint[new_path] = old_checkpoint[path["old"]]
 
 
 def convert_ldm_checkpoint(checkpoint, config):
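
The `[:, :, 0]` slice flagged in the hunk above is the conv-1D-to-linear conversion for `proj_attn.weight`. A minimal sketch (not part of the commit, with a hypothetical channel count):

```python
import torch

# A Conv1d weight of shape (out_channels, in_channels, 1) becomes a Linear weight
# of shape (out_channels, in_channels) by dropping the trailing kernel dimension.
conv_weight = torch.randn(64, 64, 1)  # hypothetical proj_out weight from the LDM checkpoint
linear_weight = conv_weight[:, :, 0]
print(linear_weight.shape)  # torch.Size([64, 64])
```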
@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||||
"""
|
"""
|
||||||
new_checkpoint = {}
|
new_checkpoint = {}
|
||||||
|
|
||||||
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight']
|
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
|
||||||
new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias']
|
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
|
||||||
new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight']
|
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
|
||||||
new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias']
|
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
|
||||||
|
|
||||||
new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight']
|
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
|
||||||
new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias']
|
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
|
||||||
|
|
||||||
new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight']
|
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
||||||
new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias']
|
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
|
||||||
new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight']
|
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
|
||||||
new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']
|
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
|
||||||
|
|
||||||
# Retrieves the keys for the input blocks only
|
# Retrieves the keys for the input blocks only
|
||||||
num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
|
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
|
||||||
input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
|
input_blocks = {
|
||||||
|
layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_input_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
# Retrieves the keys for the middle blocks only
|
# Retrieves the keys for the middle blocks only
|
||||||
num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer})
|
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
|
||||||
middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
|
middle_blocks = {
|
||||||
|
layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_middle_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
# Retrieves the keys for the output blocks only
|
# Retrieves the keys for the output blocks only
|
||||||
num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer})
|
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
|
||||||
output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
|
output_blocks = {
|
||||||
|
layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_output_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
for i in range(1, num_input_blocks):
|
for i in range(1, num_input_blocks):
|
||||||
block_id = (i - 1) // (config['num_res_blocks'] + 1)
|
block_id = (i - 1) // (config["num_res_blocks"] + 1)
|
||||||
layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
|
layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)
|
||||||
|
|
||||||
resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key]
|
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
|
||||||
attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
|
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||||
|
|
||||||
if f'input_blocks.{i}.0.op.weight' in checkpoint:
|
if f"input_blocks.{i}.0.op.weight" in checkpoint:
|
||||||
new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.weight'] = checkpoint[f'input_blocks.{i}.0.op.weight']
|
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
|
||||||
new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.bias'] = checkpoint[f'input_blocks.{i}.0.op.bias']
|
f"input_blocks.{i}.0.op.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
|
||||||
|
f"input_blocks.{i}.0.op.bias"
|
||||||
|
]
|
||||||
|
|
||||||
paths = renew_resnet_paths(resnets)
|
paths = renew_resnet_paths(resnets)
|
||||||
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||||
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
|
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
|
assign_to_checkpoint(
|
||||||
|
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
|
||||||
|
)
|
||||||
|
|
||||||
if len(attentions):
|
if len(attentions):
|
||||||
paths = renew_attention_paths(attentions)
|
paths = renew_attention_paths(attentions)
|
||||||
meta_path = {'old': f'input_blocks.{i}.1', 'new': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}'}
|
meta_path = {
|
||||||
|
"old": f"input_blocks.{i}.1",
|
||||||
|
"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||||
|
}
|
||||||
to_split = {
|
to_split = {
|
||||||
f'input_blocks.{i}.1.qkv.bias': {
|
f"input_blocks.{i}.1.qkv.bias": {
|
||||||
'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
|
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
||||||
'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
|
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
||||||
'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
|
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
||||||
},
|
},
|
||||||
f'input_blocks.{i}.1.qkv.weight': {
|
f"input_blocks.{i}.1.qkv.weight": {
|
||||||
'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
|
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
||||||
'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
|
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
||||||
'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
|
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assign_to_checkpoint(
|
assign_to_checkpoint(
|
||||||
|
@@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
                 checkpoint,
                 additional_replacements=[meta_path],
                 attention_paths_to_split=to_split,
-                config=config
+                config=config,
             )
 
     resnet_0 = middle_blocks[0]
@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||||
|
|
||||||
attentions_paths = renew_attention_paths(attentions)
|
attentions_paths = renew_attention_paths(attentions)
|
||||||
to_split = {
|
to_split = {
|
||||||
'middle_block.1.qkv.bias': {
|
"middle_block.1.qkv.bias": {
|
||||||
'key': 'mid_block.attentions.0.key.bias',
|
"key": "mid_block.attentions.0.key.bias",
|
||||||
'query': 'mid_block.attentions.0.query.bias',
|
"query": "mid_block.attentions.0.query.bias",
|
||||||
'value': 'mid_block.attentions.0.value.bias',
|
"value": "mid_block.attentions.0.value.bias",
|
||||||
},
|
},
|
||||||
'middle_block.1.qkv.weight': {
|
"middle_block.1.qkv.weight": {
|
||||||
'key': 'mid_block.attentions.0.key.weight',
|
"key": "mid_block.attentions.0.key.weight",
|
||||||
'query': 'mid_block.attentions.0.query.weight',
|
"query": "mid_block.attentions.0.query.weight",
|
||||||
'value': 'mid_block.attentions.0.value.weight',
|
"value": "mid_block.attentions.0.value.weight",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
|
assign_to_checkpoint(
|
||||||
|
attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(num_output_blocks):
|
for i in range(num_output_blocks):
|
||||||
block_id = i // (config['num_res_blocks'] + 1)
|
block_id = i // (config["num_res_blocks"] + 1)
|
||||||
layer_in_block_id = i % (config['num_res_blocks'] + 1)
|
layer_in_block_id = i % (config["num_res_blocks"] + 1)
|
||||||
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
||||||
output_block_list = {}
|
output_block_list = {}
|
||||||
|
|
||||||
for layer in output_block_layers:
|
for layer in output_block_layers:
|
||||||
layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
|
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
||||||
if layer_id in output_block_list:
|
if layer_id in output_block_list:
|
||||||
output_block_list[layer_id].append(layer_name)
|
output_block_list[layer_id].append(layer_name)
|
||||||
else:
|
else:
|
||||||
output_block_list[layer_id] = [layer_name]
|
output_block_list[layer_id] = [layer_name]
|
||||||
|
|
||||||
if len(output_block_list) > 1:
|
if len(output_block_list) > 1:
|
||||||
resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
|
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||||
attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
|
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||||
|
|
||||||
resnet_0_paths = renew_resnet_paths(resnets)
|
resnet_0_paths = renew_resnet_paths(resnets)
|
||||||
paths = renew_resnet_paths(resnets)
|
paths = renew_resnet_paths(resnets)
|
||||||
|
|
||||||
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||||
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
||||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
|
f"output_blocks.{i}.{index}.conv.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
|
||||||
|
f"output_blocks.{i}.{index}.conv.bias"
|
||||||
|
]
|
||||||
|
|
||||||
# Clear attentions as they have been attributed above.
|
# Clear attentions as they have been attributed above.
|
||||||
if len(attentions) == 2:
|
if len(attentions) == 2:
|
||||||
|
@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||||
if len(attentions):
|
if len(attentions):
|
||||||
paths = renew_attention_paths(attentions)
|
paths = renew_attention_paths(attentions)
|
||||||
meta_path = {
|
meta_path = {
|
||||||
'old': f'output_blocks.{i}.1',
|
"old": f"output_blocks.{i}.1",
|
||||||
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
|
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||||
}
|
}
|
||||||
to_split = {
|
to_split = {
|
||||||
f'output_blocks.{i}.1.qkv.bias': {
|
f"output_blocks.{i}.1.qkv.bias": {
|
||||||
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
|
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
||||||
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
|
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
||||||
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
|
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
||||||
},
|
},
|
||||||
f'output_blocks.{i}.1.qkv.weight': {
|
f"output_blocks.{i}.1.qkv.weight": {
|
||||||
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
|
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
||||||
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
|
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
||||||
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
|
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assign_to_checkpoint(
|
assign_to_checkpoint(
|
||||||
|
@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||||
new_checkpoint,
|
new_checkpoint,
|
||||||
checkpoint,
|
checkpoint,
|
||||||
additional_replacements=[meta_path],
|
additional_replacements=[meta_path],
|
||||||
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
|
attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||||
for path in resnet_0_paths:
|
for path in resnet_0_paths:
|
||||||
old_path = '.'.join(['output_blocks', str(i), path['old']])
|
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
||||||
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
|
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
||||||
|
|
||||||
new_checkpoint[new_path] = checkpoint[old_path]
|
new_checkpoint[new_path] = checkpoint[old_path]
|
||||||
|
|
||||||
|
@@ -303,9 +331,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
 
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
 
@@ -16,8 +16,10 @@

 import argparse
 import json

 import torch
-from diffusers import UNet2DModel
+
+from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


 def convert_ncsnpp_checkpoint(checkpoint, config):
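The widened import (adding `ScoreSdeVePipeline` and `ScoreSdeVeScheduler`) suggests the converted NCSNPP UNet is wrapped into a score-SDE VE pipeline before saving. Below is only a rough sketch of such a wrap-and-save step, not this script's code; the default scheduler arguments, the stand-in model config, and the output directory are assumptions.

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel

# Stand-in for the UNet2DModel that convert_ncsnpp_checkpoint would produce.
unet = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)
scheduler = ScoreSdeVeScheduler()  # default arguments are an assumption here
pipeline = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
pipeline.save_pretrained("ncsnpp-converted")  # illustrative output path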
@@ -1,71 +1,105 @@
-from huggingface_hub import HfApi
-from transformers.file_utils import has_file
-from diffusers import UNet2DModel
 import random

 import torch

+from diffusers import UNet2DModel
+from huggingface_hub import HfApi


 api = HfApi()

 results = {}
-results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
-    1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
-    -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
-    0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
-results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
-    1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
-    -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
-    2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365])
-results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
-    -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
-    -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
-    0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943])
-results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
-    -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
-    0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
-    -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505])
-results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
-    -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
-    0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
-    -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386])
-results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
-    -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
-    0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
-    -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431])
-results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
-    -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
-    0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
-    -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390])
-results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
-    -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
-    0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
-    -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473])
-results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
-    1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
-    -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
-    1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
-results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
-    0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
-    -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
-    1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266])
-results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
-    0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
-    -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
-    1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355])
-results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
-    1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
-    -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
-    3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066])
-results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
-    1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
-    -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
-    2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243])
-results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
-    1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
-    -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
-    3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343])
-results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
-    1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
-    -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
-    1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219])
+# fmt: off
+results["google_ddpm_cifar10_32"] = torch.tensor([
+    -0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
+    1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
+    -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
+    0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
+])
+results["google_ddpm_ema_bedroom_256"] = torch.tensor([
+    -2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
+    1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
+    -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
+    2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
+])
+results["CompVis_ldm_celebahq_256"] = torch.tensor([
+    -0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
+    -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
+    -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
+    0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
+])
+results["google_ncsnpp_ffhq_1024"] = torch.tensor([
+    0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
+    -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
+    0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
+    -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
+])
+results["google_ncsnpp_bedroom_256"] = torch.tensor([
+    0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
+    -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
+    0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
+    -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
+])
+results["google_ncsnpp_celebahq_256"] = torch.tensor([
+    0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
+    -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
+    0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
+    -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
+])
+results["google_ncsnpp_church_256"] = torch.tensor([
+    0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
+    -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
+    0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
+    -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
+])
+results["google_ncsnpp_ffhq_256"] = torch.tensor([
+    0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
+    -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
+    0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
+    -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
+])
+results["google_ddpm_cat_256"] = torch.tensor([
+    -1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
+    1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
+    -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
+    1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251
+])
+results["google_ddpm_celebahq_256"] = torch.tensor([
+    -1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
+    0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
+    -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
+    1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
+])
+results["google_ddpm_ema_celebahq_256"] = torch.tensor([
+    -1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
+    0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
+    -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
+    1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
+])
+results["google_ddpm_church_256"] = torch.tensor([
+    -2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
+    1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
+    -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
+    3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
+])
+results["google_ddpm_bedroom_256"] = torch.tensor([
+    -2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
+    1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
+    -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
+    2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
+])
+results["google_ddpm_ema_church_256"] = torch.tensor([
+    -2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
+    1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
+    -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
+    3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
+])
+results["google_ddpm_ema_cat_256"] = torch.tensor([
+    -1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
+    1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
+    -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
+    1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
+])
+# fmt: on

 models = api.list_models(filter="diffusers")
 for mod in models:
@@ -75,7 +109,7 @@ for mod in models:
    print(f"Started running {mod.modelId}!!!")

    if mod.modelId.startswith("CompVis"):
-        model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
+        model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
    else:
        model = UNet2DModel.from_pretrained(local_checkpoint)
@@ -85,7 +119,9 @@ for mod in models:
    noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    time_step = torch.tensor([10] * noise.shape[0])
    with torch.no_grad():
-        logits = model(noise, time_step)['sample']
+        logits = model(noise, time_step)["sample"]

-    assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
+    assert torch.allclose(
+        logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
+    )
    print(f"{mod.modelId} has passed succesfully!!!")
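As a side note on the assertion above: the `results` lookup key is derived from the Hub model id by replacing both `/` and `-` with underscores. A small standalone illustration follows; the example id is ours, chosen because it matches the `google_ddpm_cifar10_32` entry above.

# Minimal illustration of the key derivation used in the assert:
# "google/ddpm-cifar10-32" -> "google_ddpm_cifar10_32"
model_id = "google/ddpm-cifar10-32"
key = "_".join("_".join(model_id.split("/")).split("-"))
assert key == "google_ddpm_cifar10_32"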