improve ddpm conversion script
commit 3f0b44b322
parent cb90fd69b4
@@ -1,4 +1,4 @@
-from diffusers import UNetUnconditionalModel
+from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
 import argparse
 import json
 import torch
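Note: the newly imported DDPMScheduler and DDPMPipeline are consumed by the pipeline-export logic this commit adds in the final hunk below.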
@@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
 
     if attention_paths_to_split is not None:
         if config is None:
-            raise ValueError(f"Please specify the config if setting 'attention_paths_to_split' to 'True'.")
+            raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
 
         for path, path_map in attention_paths_to_split.items():
             old_tensor = old_checkpoint[path]
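Note: the f prefix on the ValueError message was a no-op, since the string interpolates nothing; a plain string literal avoids the f-string-without-placeholders lint (flake8 F541).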
@@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement['old'], replacement['new'])
 
-
        if 'attentions' in new_path:
            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
        else:
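Note: this deletion and the matching single-line deletions in the -97, -138, -148 and -186 hunks below only collapse doubled blank lines; they change no behavior.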
@@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
-
     new_checkpoint = {}
 
     new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
@@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
 
     for i in range(num_downsample_blocks):
         block_id = (i - 1) // (config['num_res_blocks'] + 1)
-        layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
 
         if any('downsample' in layer for layer in downsample_blocks[i]):
             new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
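Note: layer_in_block_id was computed here but, as far as the surrounding code shows, never read again, so this is dead-code removal; block_id is kept as context.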
@@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
             paths = renew_resnet_paths(blocks[j])
             assign_to_checkpoint(paths, new_checkpoint, checkpoint)
 
-
         if any('attn' in layer for layer in downsample_blocks[i]):
             num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
             attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
@@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
             paths = renew_attention_paths(attns[j])
             assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
 
-
     mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
     mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
     mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
@@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
             paths = renew_resnet_paths(blocks[j])
             assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
 
-
         if any('attn' in layer for layer in upsample_blocks[i]):
             num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
             attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
@@ -220,12 +214,21 @@ if __name__ == "__main__":
         "--dump_path", default=None, type=str, required=True, help="Path to the output model."
     )
 
 
     args = parser.parse_args()
     checkpoint = torch.load(args.checkpoint_path)
 
     with open(args.config_file) as f:
         config = json.loads(f.read())
 
-    converted_checkpoint = convert_ddpm_checkpoint(args.checkpoint_path, args.config_file)
-    torch.save(converted_checkpoint, args.dump_path)
+    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+    if "ddpm" in config:
+        del config["ddpm"]
+
+    model = UNetUnconditionalModel(**config)
+    model.load_state_dict(converted_checkpoint)
+
+    scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+
+    pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+    pipe.save_pretrained(args.dump_path)
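This last hunk carries the two substantive fixes. convert_ddpm_checkpoint expects the loaded state dict and config dict but was previously handed the raw paths (args.checkpoint_path, args.config_file). And instead of torch.save-ing a bare state dict, the script now strips the "ddpm" config key (presumably because the UNetUnconditionalModel constructor does not accept it), instantiates the model, loads the converted weights, reads the scheduler config from the checkpoint's parent directory, and saves a complete DDPMPipeline. A minimal consumption sketch, not part of the commit; the output directory name is hypothetical and the pipeline call signature varies across diffusers versions:

import torch
from diffusers import DDPMPipeline

# Load the directory written by pipe.save_pretrained(args.dump_path).
pipe = DDPMPipeline.from_pretrained("./converted-ddpm")  # hypothetical path

# Draw a single sample; recent diffusers returns an object with an `.images`
# list, while versions from this era returned a {"sample": ...} dict instead.
generator = torch.manual_seed(0)
image = pipe(batch_size=1, generator=generator).images[0]
image.save("ddpm_sample.png")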