From 3f0b44b3223b0b669693d2be5440a8a3ab90570c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Jul 2022 11:24:13 +0000 Subject: [PATCH] improve ddpm conversion script --- ...t_ddpm_original_checkpoint_to_diffusers.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py index 289fea29..b1499e28 100644 --- a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -1,4 +1,4 @@ -from diffusers import UNetUnconditionalModel +from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline import argparse import json import torch @@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s if attention_paths_to_split is not None: if config is None: - raise ValueError(f"Please specify the config if setting 'attention_paths_to_split' to 'True'.") + raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.") for path, path_map in attention_paths_to_split.items(): old_tensor = old_checkpoint[path] @@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s for replacement in additional_replacements: new_path = new_path.replace(replacement['old'], replacement['new']) - if 'attentions' in new_path: checkpoint[new_path] = old_checkpoint[path['old']].squeeze() else: @@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config): """ Takes a state dict and a config, and returns a converted checkpoint. """ - new_checkpoint = {} new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight'] @@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config): for i in range(num_downsample_blocks): block_id = (i - 1) // (config['num_res_blocks'] + 1) - layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1) if any('downsample' in layer for layer in downsample_blocks[i]): new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight'] @@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config): paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint) - if any('attn' in layer for layer in downsample_blocks[i]): num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer}) attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} @@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config): paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config) - mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key] mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key] mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key] @@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config): paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - if any('attn' in layer for layer in upsample_blocks[i]): num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer}) attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} @@ -220,12 +214,21 @@ if __name__ == "__main__": "--dump_path", default=None, type=str, required=True, help="Path to the output model." ) - args = parser.parse_args() checkpoint = torch.load(args.checkpoint_path) with open(args.config_file) as f: config = json.loads(f.read()) - converted_checkpoint = convert_ddpm_checkpoint(args.checkpoint_path, args.config_file) - torch.save(converted_checkpoint, args.dump_path) + converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config) + + if "ddpm" in config: + del config["ddpm"] + + model = UNetUnconditionalModel(**config) + model.load_state_dict(converted_checkpoint) + + scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) + + pipe = DDPMPipeline(unet=model, scheduler=scheduler) + pipe.save_pretrained(args.dump_path)