Conversion script for ncsnpp models (#98)
* added kwargs for easier initialisation of random model * initial commit for conversion script * current debug script * update * Update * done * add updated debug conversion script * style * clean conversion script
This commit is contained in:
parent
182b164f32
commit
f794432e81
|
@ -1,86 +1,110 @@
|
|||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
|
||||
from regex import P
|
||||
from diffusers import UNetUnconditionalModel
|
||||
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint
|
||||
from scripts.convert_ncsnpp_original_checkpoint_to_diffusers import convert_ncsnpp_checkpoint
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
|
||||
model_id = "fusing/latent-diffusion-celeba-256"
|
||||
subfolder = "unet"
|
||||
#model_id = "fusing/unet-ldm-dummy"
|
||||
#subfolder = None
|
||||
|
||||
checkpoint = "diffusion_model.pt"
|
||||
config = "config.json"
|
||||
|
||||
if subfolder is not None:
|
||||
checkpoint = os.path.join(subfolder, checkpoint)
|
||||
config = os.path.join(subfolder, config)
|
||||
|
||||
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
|
||||
config_path = hf_hub_download(model_id, config)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
checkpoint = convert_ldm_checkpoint(original_checkpoint, config)
|
||||
|
||||
|
||||
def current_codebase_conversion():
|
||||
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
|
||||
model.eval()
|
||||
def convert_checkpoint(model_id, subfolder=None, checkpoint = "diffusion_model.pt", config = "config.json"):
|
||||
if subfolder is not None:
|
||||
checkpoint = os.path.join(subfolder, checkpoint)
|
||||
config = os.path.join(subfolder, config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint),map_location='cpu')
|
||||
config_path = hf_hub_download(model_id, config)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
time_step = torch.tensor([10] * noise.shape[0])
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
return model.state_dict()
|
||||
checkpoint = convert_ncsnpp_checkpoint(original_checkpoint, config)
|
||||
|
||||
|
||||
currently_converted_checkpoint = current_codebase_conversion()
|
||||
def current_codebase_conversion(path):
|
||||
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, sde=True)
|
||||
model.eval()
|
||||
model.config.sde=False
|
||||
model.save_config(path)
|
||||
model.config.sde=True
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
time_step = torch.tensor([10] * noise.shape[0])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
return model.state_dict()
|
||||
|
||||
path = f"{model_id}_converted"
|
||||
currently_converted_checkpoint = current_codebase_conversion(path)
|
||||
|
||||
|
||||
def diff_between_checkpoints(ch_0, ch_1):
|
||||
all_layers_included = False
|
||||
def diff_between_checkpoints(ch_0, ch_1):
|
||||
all_layers_included = False
|
||||
|
||||
if not set(ch_0.keys()) == set(ch_1.keys()):
|
||||
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
|
||||
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
|
||||
print(f"\t{key}")
|
||||
if not set(ch_0.keys()) == set(ch_1.keys()):
|
||||
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
|
||||
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
|
||||
print(f"\t{key}")
|
||||
|
||||
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
|
||||
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
|
||||
print(f"\t{key}")
|
||||
else:
|
||||
print("Keys are the same between the two checkpoints")
|
||||
all_layers_included = True
|
||||
|
||||
keys = ch_0.keys()
|
||||
non_equal_keys = []
|
||||
|
||||
if all_layers_included:
|
||||
for key in keys:
|
||||
try:
|
||||
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
|
||||
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
|
||||
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
|
||||
|
||||
if len(non_equal_keys):
|
||||
non_equal_keys = '\n\t'.join(non_equal_keys)
|
||||
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
|
||||
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
|
||||
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
|
||||
print(f"\t{key}")
|
||||
else:
|
||||
print("All keys are equal across checkpoints.")
|
||||
print("Keys are the same between the two checkpoints")
|
||||
all_layers_included = True
|
||||
|
||||
keys = ch_0.keys()
|
||||
non_equal_keys = []
|
||||
|
||||
if all_layers_included:
|
||||
for key in keys:
|
||||
try:
|
||||
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
|
||||
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
|
||||
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
|
||||
|
||||
if len(non_equal_keys):
|
||||
non_equal_keys = '\n\t'.join(non_equal_keys)
|
||||
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
|
||||
else:
|
||||
print("All keys are equal across checkpoints.")
|
||||
|
||||
|
||||
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
|
||||
torch.save(checkpoint, "/path/to/checkpoint/")
|
||||
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
|
||||
os.makedirs( f"{model_id}_converted",exist_ok =True)
|
||||
torch.save(checkpoint, f"{model_id}_converted/diffusion_model.pt")
|
||||
|
||||
|
||||
# Hub repositories whose original NCSN++ checkpoints are converted in one batch.
model_ids = [
    "fusing/ffhq_ncsnpp",
    "fusing/church_256-ncsnpp-ve",
    "fusing/celebahq_256-ncsnpp-ve",
    "fusing/bedroom_256-ncsnpp-ve",
    "fusing/ffhq_256-ncsnpp-ve",
    "fusing/ncsnpp-ffhq-ve-dummy",
]

# Best-effort batch run: a failure on one repository is printed and the loop
# moves on to the next one instead of aborting.
for model in model_ids:
    print(f"converting {model}")
    try:
        convert_checkpoint(model)
    except Exception as e:
        print(e)
|
||||
|
||||
from tests.test_modeling_utils import PipelineTesterMixin, NCSNppModelTests
|
||||
|
||||
tester1 = NCSNppModelTests()
|
||||
tester2 = PipelineTesterMixin()
|
||||
|
||||
os.environ["RUN_SLOW"] = '1'
|
||||
cmd = "export RUN_SLOW=1; echo $RUN_SLOW" # or whatever command
|
||||
os.system(cmd)
|
||||
tester2.test_score_sde_ve_pipeline(f"{model_ids[0]}_converted")
|
||||
tester1.test_output_pretrained_ve_mid(f"{model_ids[2]}_converted")
|
||||
tester1.test_output_pretrained_ve_large(f"{model_ids[-1]}_converted")
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Conversion script for the NCSNPP checkpoints. """
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
from diffusers import UNetUnconditionalModel
|
||||
|
||||
|
||||
|
||||
def convert_ncsnpp_checkpoint(checkpoint, config):
    """
    Copy the weights of an original NCSN++ state dict into a new `UNetUnconditionalModel`.

    The original checkpoint is keyed as `all_modules.<index>.<name>`. The new model is
    walked in a fixed order (time embedding, input conv, down blocks, mid block, up
    blocks, output head) while a running `module_index` selects the matching
    `all_modules.<index>` entry for each layer.

    Args:
        checkpoint: state dict of the original model, keyed `all_modules.<index>...`.
        config: dict of keyword arguments used to build `UNetUnconditionalModel`.

    Returns:
        The state dict of the newly built model after all weights have been copied.

    Raises:
        KeyError: if an expected `all_modules.<index>...` key is missing from `checkpoint`.
    """
    new_model_architecture = UNetUnconditionalModel(**config)

    # The Fourier projection weight is written to both attribute names.
    new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data

    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    # Modules 0-3 are consumed above; everything below continues from index 4.
    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        """Copy one attention block from `all_modules.<index>` into `new_layer`."""
        prefix = f"all_modules.{index}"
        projections = (
            ("query", "NIN_0"),
            ("key", "NIN_1"),
            ("value", "NIN_2"),
            ("proj_attn", "NIN_3"),
        )
        for target_name, source_name in projections:
            target = getattr(new_layer, target_name)
            # NIN weights are transposed before being copied; biases are copied as-is.
            target.weight.data = old_checkpoint[f"{prefix}.{source_name}.W"].data.T
            target.bias.data = old_checkpoint[f"{prefix}.{source_name}.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"{prefix}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"{prefix}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        """Copy one resnet block from `all_modules.<index>` into `new_layer`."""
        prefix = f"all_modules.{index}"
        sublayers = (
            ("conv1", "Conv_0"),
            ("norm1", "GroupNorm_0"),
            ("conv2", "Conv_1"),
            ("norm2", "GroupNorm_1"),
            ("time_emb_proj", "Dense_0"),
        )
        for target_name, source_name in sublayers:
            target = getattr(new_layer, target_name)
            target.weight.data = old_checkpoint[f"{prefix}.{source_name}.weight"].data
            target.bias.data = old_checkpoint[f"{prefix}.{source_name}.bias"].data

        # The shortcut conv is only copied when channels change or the block resamples.
        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"{prefix}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"{prefix}.Conv_2.bias"].data

    # Down path: each resnet is followed by its own attention layer when the block has any.
    for block in new_model_architecture.downsample_blocks:
        has_attentions = hasattr(block, "attentions")
        for j, resnet in enumerate(block.resnets):
            set_resnet_weights(resnet, checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if hasattr(block, "downsamplers") and block.downsamplers is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    # Mid block: resnet, attention, resnet.
    set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
    module_index += 1

    # Up path: all resnets of a block first, then a single attention layer.
    for block in new_model_architecture.upsample_blocks:
        has_attentions = hasattr(block, "attentions")
        for resnet in block.resnets:
            set_resnet_weights(resnet, checkpoint, module_index)
            module_index += 1
        if has_attentions:
            # why can there only be a single attention layer for up?
            set_attention_weights(block.attentions[0], checkpoint, module_index)
            module_index += 1

        if hasattr(block, "resnet_up") and block.resnet_up is not None:
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    # Output head: the final norm and conv are the next two modules in sequence.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: load an original checkpoint and its config,
    # convert the weights, and save the resulting state dict.
    parser = argparse.ArgumentParser()

    # NOTE(review): the defaults below point at one developer's local checkout;
    # on any other machine all three options have to be passed explicitly.
    for flag, default, help_text in (
        (
            "--checkpoint_path",
            "/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
            "Path to the checkpoint to convert.",
        ),
        (
            "--config_file",
            "/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
            "The config json file corresponding to the architecture.",
        ),
        (
            "--dump_path",
            "/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
            "Path to the output model.",
        ),
    ):
        parser.add_argument(flag, default=default, type=str, required=False, help=help_text)

    args = parser.parse_args()

    # NOTE(review): torch.load unpickles the file; only use it on checkpoints you trust.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    with open(args.config_file) as f:
        config = json.load(f)

    converted_checkpoint = convert_ncsnpp_checkpoint(checkpoint, config)
    torch.save(converted_checkpoint, args.dump_path)
|
|
@ -152,6 +152,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
|||
progressive_input="input_skip",
|
||||
resnet_num_groups=32,
|
||||
continuous=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
# register all __init__ params to be accessible via `self.config.<...>`
|
||||
|
@ -454,7 +455,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
|||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.upsample_blocks:
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
|
|
Loading…
Reference in New Issue