parent 365f75233f
commit 89793a97e2

Makefile
@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := examples tests src utils
check_dirs := examples scripts src tests utils

modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@ -15,12 +15,15 @@
""" Conversion script for the LDM checkpoints. """

import argparse
import os
import json
import os

import torch
from diffusers import UNet2DModel, UNet2DConditionModel

from diffusers import UNet2DConditionModel, UNet2DModel
from transformers.file_utils import has_file


do_only_config = False
do_only_weights = True
do_only_renaming = False

@ -37,9 +40,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.",
)

parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

args = parser.parse_args()
@ -1,9 +1,10 @@
import argparse

import OmegaConf
import torch

from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
import OmegaConf
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel


def convert_ldm_original(checkpoint_path, config_path, output_path):
config = OmegaConf.load(config_path)

@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
for key in keys:
if key.startswith(first_stage_key):
first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]


# extract state_dict for UNetLDM
unet_state_dict = {}
unet_key = "model.diffusion_model."
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]


vqvae_init_args = config.model.params.first_stage_config.params
unet_init_args = config.model.params.unet_config.params

@ -53,4 +54,3 @@ if __name__ == "__main__":
args = parser.parse_args()

convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
@ -1,31 +1,33 @@
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
import argparse
import json

import torch

from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return '.'.join(path.split('.')[n_shave_prefix_segments:])
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return '.'.join(path.split('.')[:n_shave_prefix_segments])
return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace('block.', 'resnets.')
new_item = new_item.replace('conv_shorcut', 'conv1')
new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
new_item = new_item.replace('temb_proj', 'time_emb_proj')
new_item = new_item.replace("block.", "resnets.")
new_item = new_item.replace("conv_shorcut", "conv1")
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = new_item.replace("temb_proj", "time_emb_proj")

new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})

return mapping
@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):

# In `model.mid`, the layer is called `attn`.
if not in_mid:
new_item = new_item.replace('attn', 'attentions')
new_item = new_item.replace('.k.', '.key.')
new_item = new_item.replace('.v.', '.value.')
new_item = new_item.replace('.q.', '.query.')
new_item = new_item.replace("attn", "attentions")
new_item = new_item.replace(".k.", ".key.")
new_item = new_item.replace(".v.", ".value.")
new_item = new_item.replace(".q.", ".query.")

new_item = new_item.replace('proj_out', 'proj_attn')
new_item = new_item.replace('norm', 'group_norm')
new_item = new_item.replace("proj_out", "proj_attn")
new_item = new_item.replace("norm", "group_norm")

new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})

return mapping


def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

if attention_paths_to_split is not None:
@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()
checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()

for path in paths:
new_path = path['new']
new_path = path["new"]

if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue

new_path = new_path.replace('down.', 'down_blocks.')
new_path = new_path.replace('up.', 'up_blocks.')
new_path = new_path.replace("down.", "down_blocks.")
new_path = new_path.replace("up.", "up_blocks.")

if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new'])
new_path = new_path.replace(replacement["old"], replacement["new"])

if 'attentions' in new_path:
checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
if "attentions" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
else:
checkpoint[new_path] = old_checkpoint[path['old']]
checkpoint[new_path] = old_checkpoint[path["old"]]


def convert_ddpm_checkpoint(checkpoint, config):
@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
|
|||
"""
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
|
||||
new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
|
||||
new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
|
||||
new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
|
||||
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
|
||||
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
|
||||
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
|
||||
|
||||
new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
|
||||
new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']
|
||||
new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
|
||||
new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
|
||||
|
||||
new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
|
||||
new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
|
||||
new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
|
||||
new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
|
||||
new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
|
||||
new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
|
||||
new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
|
||||
new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
|
||||
|
||||
num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
|
||||
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
|
||||
num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
|
||||
down_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
|
||||
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
||||
num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
|
||||
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
block_id = (i - 1) // (config['layers_per_block'] + 1)
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
|
||||
if any('downsample' in layer for layer in down_blocks[i]):
|
||||
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
|
||||
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
|
||||
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
|
||||
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
|
||||
if any("downsample" in layer for layer in down_blocks[i]):
|
||||
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
|
||||
f"down.{i}.downsample.op.weight"
|
||||
]
|
||||
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
|
||||
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
|
||||
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
|
||||
|
||||
if any('block' in layer for layer in down_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("block" in layer for layer in down_blocks[i]):
|
||||
num_blocks = len(
|
||||
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
|
||||
)
|
||||
blocks = {
|
||||
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
|
||||
for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_blocks > 0:
|
||||
for j in range(config['layers_per_block']):
|
||||
for j in range(config["layers_per_block"]):
|
||||
paths = renew_resnet_paths(blocks[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
||||
|
||||
if any('attn' in layer for layer in down_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("attn" in layer for layer in down_blocks[i]):
|
||||
num_attn = len(
|
||||
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
|
||||
)
|
||||
attns = {
|
||||
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
|
||||
for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['layers_per_block']):
|
||||
for j in range(config["layers_per_block"]):
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
||||
|
||||
|
@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
|
|||
|
||||
# Mid new 2
|
||||
paths = renew_resnet_paths(mid_block_1_layers)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
|
||||
])
|
||||
assign_to_checkpoint(
|
||||
paths,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
|
||||
)
|
||||
|
||||
paths = renew_resnet_paths(mid_block_2_layers)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
|
||||
])
|
||||
assign_to_checkpoint(
|
||||
paths,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
|
||||
)
|
||||
|
||||
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
|
||||
])
|
||||
assign_to_checkpoint(
|
||||
paths,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
|
||||
)
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
|
||||
if any('upsample' in layer for layer in up_blocks[i]):
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
|
||||
if any("upsample" in layer for layer in up_blocks[i]):
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
||||
f"up.{i}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
|
||||
|
||||
if any('block' in layer for layer in up_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("block" in layer for layer in up_blocks[i]):
|
||||
num_blocks = len(
|
||||
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
|
||||
)
|
||||
blocks = {
|
||||
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_blocks > 0:
|
||||
for j in range(config['layers_per_block'] + 1):
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
for j in range(config["layers_per_block"] + 1):
|
||||
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||
paths = renew_resnet_paths(blocks[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
if any('attn' in layer for layer in up_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("attn" in layer for layer in up_blocks[i]):
|
||||
num_attn = len(
|
||||
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
|
||||
)
|
||||
attns = {
|
||||
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['layers_per_block'] + 1):
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
for j in range(config["layers_per_block"] + 1):
|
||||
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
|
||||
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
|
@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
|
|||
"""
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
|
||||
new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
|
||||
new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
|
||||
new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
|
||||
|
||||
new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
|
||||
new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
|
||||
new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
|
||||
new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
|
||||
new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
|
||||
new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
|
||||
new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
|
||||
new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
|
||||
|
||||
new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
|
||||
new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
|
||||
new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
|
||||
new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
|
||||
|
||||
new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
|
||||
new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
|
||||
new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
|
||||
new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
|
||||
new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
|
||||
new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
|
||||
new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
|
||||
new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
|
||||
|
||||
num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
|
||||
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
|
||||
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
|
||||
down_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
|
||||
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
||||
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
|
||||
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
block_id = (i - 1) // (config['layers_per_block'] + 1)
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
|
||||
if any('downsample' in layer for layer in down_blocks[i]):
|
||||
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
|
||||
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
|
||||
if any("downsample" in layer for layer in down_blocks[i]):
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
]
|
||||
|
||||
if any('block' in layer for layer in down_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("block" in layer for layer in down_blocks[i]):
|
||||
num_blocks = len(
|
||||
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
|
||||
)
|
||||
blocks = {
|
||||
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
|
||||
for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_blocks > 0:
|
||||
for j in range(config['layers_per_block']):
|
||||
for j in range(config["layers_per_block"]):
|
||||
paths = renew_resnet_paths(blocks[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
||||
|
||||
if any('attn' in layer for layer in down_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("attn" in layer for layer in down_blocks[i]):
|
||||
num_attn = len(
|
||||
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
|
||||
)
|
||||
attns = {
|
||||
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
|
||||
for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['layers_per_block']):
|
||||
for j in range(config["layers_per_block"]):
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
||||
|
||||
|
@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
|
|||
|
||||
# Mid new 2
|
||||
paths = renew_resnet_paths(mid_block_1_layers)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
|
||||
])
|
||||
assign_to_checkpoint(
|
||||
paths,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
|
||||
)
|
||||
|
||||
paths = renew_resnet_paths(mid_block_2_layers)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
|
||||
])
|
||||
assign_to_checkpoint(
|
||||
paths,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
|
||||
)
|
||||
|
||||
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
|
||||
])
|
||||
assign_to_checkpoint(
|
||||
paths,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
|
||||
)
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
|
||||
if any('upsample' in layer for layer in up_blocks[i]):
|
||||
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
|
||||
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
|
||||
if any("upsample" in layer for layer in up_blocks[i]):
|
||||
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
||||
f"decoder.up.{i}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
|
||||
f"decoder.up.{i}.upsample.conv.bias"
|
||||
]
|
||||
|
||||
if any('block' in layer for layer in up_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("block" in layer for layer in up_blocks[i]):
|
||||
num_blocks = len(
|
||||
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
|
||||
)
|
||||
blocks = {
|
||||
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_blocks > 0:
|
||||
for j in range(config['layers_per_block'] + 1):
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
for j in range(config["layers_per_block"] + 1):
|
||||
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||
paths = renew_resnet_paths(blocks[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
if any('attn' in layer for layer in up_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any("attn" in layer for layer in up_blocks[i]):
|
||||
num_attn = len(
|
||||
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
|
||||
)
|
||||
attns = {
|
||||
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
|
||||
}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['layers_per_block'] + 1):
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
for j in range(config["layers_per_block"] + 1):
|
||||
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
|
||||
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
|
||||
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
|
||||
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
|
||||
if "quantize.embedding.weight" in checkpoint:
|
||||
|
@ -321,9 +395,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.",
)

parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path)
@ -16,8 +16,10 @@

import argparse
import json

import torch
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline

from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):

@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return '.'.join(path.split('.')[n_shave_prefix_segments:])
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return '.'.join(path.split('.')[:n_shave_prefix_segments])
return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):

@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace('in_layers.0', 'norm1')
new_item = new_item.replace('in_layers.2', 'conv1')
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")

new_item = new_item.replace('out_layers.0', 'norm2')
new_item = new_item.replace('out_layers.3', 'conv2')
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")

new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
new_item = new_item.replace('skip_connection', 'conv_shortcut')
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")

new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})

return mapping
@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
for old_item in old_list:
new_item = old_item

new_item = new_item.replace('norm.weight', 'group_norm.weight')
new_item = new_item.replace('norm.bias', 'group_norm.bias')
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")

new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

mapping.append({'old': old_item, 'new': new_item})
mapping.append({"old": old_item, "new": new_item})

return mapping


def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

checkpoint[path_map['query']] = query.reshape(target_shape)
checkpoint[path_map['key']] = key.reshape(target_shape)
checkpoint[path_map['value']] = value.reshape(target_shape)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)

for path in paths:
new_path = path['new']
new_path = path["new"]

# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue

# Global renaming happens here
new_path = new_path.replace('middle_block.0', 'mid.resnets.0')
new_path = new_path.replace('middle_block.1', 'mid.attentions.0')
new_path = new_path.replace('middle_block.2', 'mid.resnets.1')
new_path = new_path.replace("middle_block.0", "mid.resnets.0")
new_path = new_path.replace("middle_block.1", "mid.attentions.0")
new_path = new_path.replace("middle_block.2", "mid.resnets.1")

if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new'])
new_path = new_path.replace(replacement["old"], replacement["new"])

# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path['old']]
checkpoint[new_path] = old_checkpoint[path["old"]]


def convert_ldm_checkpoint(checkpoint, config):
@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||
"""
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight']
|
||||
new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias']
|
||||
new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight']
|
||||
new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias']
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
|
||||
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
|
||||
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
|
||||
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
|
||||
|
||||
new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight']
|
||||
new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias']
|
||||
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
|
||||
|
||||
new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight']
|
||||
new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias']
|
||||
new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight']
|
||||
new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']
|
||||
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
||||
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
|
||||
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
|
||||
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
|
||||
input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer})
|
||||
middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer})
|
||||
output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config['num_res_blocks'] + 1)
|
||||
layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
|
||||
block_id = (i - 1) // (config["num_res_blocks"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)
|
||||
|
||||
resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key]
|
||||
attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
|
||||
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
|
||||
if f'input_blocks.{i}.0.op.weight' in checkpoint:
|
||||
new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.weight'] = checkpoint[f'input_blocks.{i}.0.op.weight']
|
||||
new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.bias'] = checkpoint[f'input_blocks.{i}.0.op.bias']
|
||||
if f"input_blocks.{i}.0.op.weight" in checkpoint:
|
||||
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
]
|
||||
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
]
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
|
||||
)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {'old': f'input_blocks.{i}.1', 'new': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}'}
|
||||
meta_path = {
|
||||
"old": f"input_blocks.{i}.1",
|
||||
"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
to_split = {
|
||||
f'input_blocks.{i}.1.qkv.bias': {
|
||||
'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
|
||||
'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
|
||||
'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
|
||||
f"input_blocks.{i}.1.qkv.bias": {
|
||||
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
||||
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
||||
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
||||
},
|
||||
f'input_blocks.{i}.1.qkv.weight': {
|
||||
'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
|
||||
'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
|
||||
'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
|
||||
f"input_blocks.{i}.1.qkv.weight": {
|
||||
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
||||
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
||||
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
|
@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
checkpoint,
additional_replacements=[meta_path],
attention_paths_to_split=to_split,
config=config
config=config,
)

resnet_0 = middle_blocks[0]
@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
to_split = {
|
||||
'middle_block.1.qkv.bias': {
|
||||
'key': 'mid_block.attentions.0.key.bias',
|
||||
'query': 'mid_block.attentions.0.query.bias',
|
||||
'value': 'mid_block.attentions.0.value.bias',
|
||||
"middle_block.1.qkv.bias": {
|
||||
"key": "mid_block.attentions.0.key.bias",
|
||||
"query": "mid_block.attentions.0.query.bias",
|
||||
"value": "mid_block.attentions.0.value.bias",
|
||||
},
|
||||
'middle_block.1.qkv.weight': {
|
||||
'key': 'mid_block.attentions.0.key.weight',
|
||||
'query': 'mid_block.attentions.0.query.weight',
|
||||
'value': 'mid_block.attentions.0.value.weight',
|
||||
"middle_block.1.qkv.weight": {
|
||||
"key": "mid_block.attentions.0.key.weight",
|
||||
"query": "mid_block.attentions.0.query.weight",
|
||||
"value": "mid_block.attentions.0.value.weight",
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
|
||||
assign_to_checkpoint(
|
||||
attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
|
||||
)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config['num_res_blocks'] + 1)
|
||||
layer_in_block_id = i % (config['num_res_blocks'] + 1)
|
||||
block_id = i // (config["num_res_blocks"] + 1)
|
||||
layer_in_block_id = i % (config["num_res_blocks"] + 1)
|
||||
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
||||
output_block_list = {}
|
||||
|
||||
for layer in output_block_layers:
|
||||
layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
|
||||
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
||||
if layer_id in output_block_list:
|
||||
output_block_list[layer_id].append(layer_name)
|
||||
else:
|
||||
output_block_list[layer_id] = [layer_name]
|
||||
|
||||
if len(output_block_list) > 1:
|
||||
resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
|
||||
attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
|
||||
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnets)
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
||||
|
||||
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
|
||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
|
||||
f"output_blocks.{i}.{index}.conv.bias"
|
||||
]
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
|
@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
|
|||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
'old': f'output_blocks.{i}.1',
|
||||
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
|
||||
"old": f"output_blocks.{i}.1",
|
||||
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
to_split = {
|
||||
f'output_blocks.{i}.1.qkv.bias': {
|
||||
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
|
||||
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
|
||||
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
|
||||
f"output_blocks.{i}.1.qkv.bias": {
|
||||
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
||||
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
||||
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
||||
},
|
||||
f'output_blocks.{i}.1.qkv.weight': {
|
||||
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
|
||||
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
|
||||
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
|
||||
f"output_blocks.{i}.1.qkv.weight": {
|
||||
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
||||
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
||||
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
|
@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint,
checkpoint,
additional_replacements=[meta_path],
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
config=config,
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = '.'.join(['output_blocks', str(i), path['old']])
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

new_checkpoint[new_path] = checkpoint[old_path]
@ -303,9 +331,7 @@ if __name__ == "__main__":
help="The config json file corresponding to the architecture.",
)

parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

args = parser.parse_args()
@ -16,8 +16,10 @@

import argparse
import json

import torch
from diffusers import UNet2DModel

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
@ -1,91 +1,127 @@
from huggingface_hub import HfApi
from transformers.file_utils import has_file
from diffusers import UNet2DModel
import random

import torch

from diffusers import UNet2DModel
from huggingface_hub import HfApi


api = HfApi()

results = {}
results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
|
||||
1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
|
||||
-1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
|
||||
0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
|
||||
results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
|
||||
1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
|
||||
-2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
|
||||
2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365])
|
||||
results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
|
||||
-0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
|
||||
-0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
|
||||
0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943])
|
||||
results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
|
||||
-0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
|
||||
0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
|
||||
-0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505])
|
||||
results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
|
||||
-0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
|
||||
0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
|
||||
-0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386])
|
||||
results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
|
||||
-0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
|
||||
0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
|
||||
-0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431])
|
||||
results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
|
||||
-0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
|
||||
0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
|
||||
-0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390])
|
||||
results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
|
||||
-0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
|
||||
0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
|
||||
-0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473])
|
||||
results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
|
||||
1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
|
||||
-2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
|
||||
1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
|
||||
results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
|
||||
0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
|
||||
-2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
|
||||
1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266])
|
||||
results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
|
||||
0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
|
||||
-2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
|
||||
1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355])
|
||||
results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
|
||||
1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
|
||||
-3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
|
||||
3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066])
|
||||
results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
|
||||
1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
|
||||
-2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
|
||||
2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243])
|
||||
results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
|
||||
1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
|
||||
-3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
|
||||
3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343])
|
||||
results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
|
||||
1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
|
||||
-2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
|
||||
1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219])
|
||||
# fmt: off
|
||||
results["google_ddpm_cifar10_32"] = torch.tensor([
|
||||
-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
|
||||
1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
|
||||
-1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
|
||||
0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
|
||||
])
|
||||
results["google_ddpm_ema_bedroom_256"] = torch.tensor([
|
||||
-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
|
||||
1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
|
||||
-2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
|
||||
2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
|
||||
])
|
||||
results["CompVis_ldm_celebahq_256"] = torch.tensor([
|
||||
-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
|
||||
-0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
|
||||
-0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
|
||||
0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
|
||||
])
|
||||
results["google_ncsnpp_ffhq_1024"] = torch.tensor([
|
||||
0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
|
||||
-0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
|
||||
0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
|
||||
-0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
|
||||
])
|
||||
results["google_ncsnpp_bedroom_256"] = torch.tensor([
|
||||
0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
|
||||
-0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
|
||||
0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
|
||||
-0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
|
||||
])
|
||||
results["google_ncsnpp_celebahq_256"] = torch.tensor([
|
||||
0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
|
||||
-0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
|
||||
0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
|
||||
-0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
|
||||
])
|
||||
results["google_ncsnpp_church_256"] = torch.tensor([
|
||||
0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
|
||||
-0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
|
||||
0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
|
||||
-0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
|
||||
])
|
||||
results["google_ncsnpp_ffhq_256"] = torch.tensor([
|
||||
0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
|
||||
-0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
|
||||
0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
|
||||
-0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
|
||||
])
|
||||
results["google_ddpm_cat_256"] = torch.tensor([
|
||||
-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
|
||||
1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
|
||||
-2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
|
||||
1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
|
||||
results["google_ddpm_celebahq_256"] = torch.tensor([
|
||||
-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
|
||||
0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
|
||||
-2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
|
||||
1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
|
||||
])
|
||||
results["google_ddpm_ema_celebahq_256"] = torch.tensor([
|
||||
-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
|
||||
0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
|
||||
-2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
|
||||
1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
|
||||
])
|
||||
results["google_ddpm_church_256"] = torch.tensor([
|
||||
-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
|
||||
1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
|
||||
-3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
|
||||
3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
|
||||
])
|
||||
results["google_ddpm_bedroom_256"] = torch.tensor([
|
||||
-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
|
||||
1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
|
||||
-2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
|
||||
2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
|
||||
])
|
||||
results["google_ddpm_ema_church_256"] = torch.tensor([
|
||||
-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
|
||||
1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
|
||||
-3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
|
||||
3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
|
||||
])
|
||||
results["google_ddpm_ema_cat_256"] = torch.tensor([
|
||||
-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
|
||||
1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
|
||||
-2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
|
||||
1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]

print(f"Started running {mod.modelId}!!!")

if mod.modelId.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
else:
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
else:
model = UNet2DModel.from_pretrained(local_checkpoint)


torch.manual_seed(0)
random.seed(0)


noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)['sample']
logits = model(noise, time_step)["sample"]

assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
)
print(f"{mod.modelId} has passed succesfully!!!")