improve ddpm conversion script

This commit is contained in:
Patrick von Platen 2022-07-19 11:24:13 +00:00
parent cb90fd69b4
commit 3f0b44b322
1 changed files with 14 additions and 11 deletions

View File

@ -1,4 +1,4 @@
from diffusers import UNetUnconditionalModel from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
import argparse import argparse
import json import json
import torch import torch
@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
if attention_paths_to_split is not None: if attention_paths_to_split is not None:
if config is None: if config is None:
raise ValueError(f"Please specify the config if setting 'attention_paths_to_split' to 'True'.") raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
for path, path_map in attention_paths_to_split.items(): for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path] old_tensor = old_checkpoint[path]
@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
for replacement in additional_replacements: for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new']) new_path = new_path.replace(replacement['old'], replacement['new'])
if 'attentions' in new_path: if 'attentions' in new_path:
checkpoint[new_path] = old_checkpoint[path['old']].squeeze() checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
else: else:
@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
""" """
Takes a state dict and a config, and returns a converted checkpoint. Takes a state dict and a config, and returns a converted checkpoint.
""" """
new_checkpoint = {} new_checkpoint = {}
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight'] new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
for i in range(num_downsample_blocks): for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1) block_id = (i - 1) // (config['num_res_blocks'] + 1)
layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
if any('downsample' in layer for layer in downsample_blocks[i]): if any('downsample' in layer for layer in downsample_blocks[i]):
new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths = renew_resnet_paths(blocks[j]) paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint) assign_to_checkpoint(paths, new_checkpoint, checkpoint)
if any('attn' in layer for layer in downsample_blocks[i]): if any('attn' in layer for layer in downsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer}) num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attns[j]) paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key] mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key] mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key] mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths = renew_resnet_paths(blocks[j]) paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any('attn' in layer for layer in upsample_blocks[i]): if any('attn' in layer for layer in upsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer}) num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
@ -220,12 +214,21 @@ if __name__ == "__main__":
"--dump_path", default=None, type=str, required=True, help="Path to the output model." "--dump_path", default=None, type=str, required=True, help="Path to the output model."
) )
args = parser.parse_args() args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path) checkpoint = torch.load(args.checkpoint_path)
with open(args.config_file) as f: with open(args.config_file) as f:
config = json.loads(f.read()) config = json.loads(f.read())
converted_checkpoint = convert_ddpm_checkpoint(args.checkpoint_path, args.config_file) converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
torch.save(converted_checkpoint, args.dump_path)
if "ddpm" in config:
del config["ddpm"]
model = UNetUnconditionalModel(**config)
model.load_state_dict(converted_checkpoint)
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)