diff --git a/debug_conversion.py b/debug_conversion.py
index a8be53e8..a32ce784 100755
--- a/debug_conversion.py
+++ b/debug_conversion.py
@@ -1,86 +1,110 @@
 #!/usr/bin/env python3
 import json
 import os
+
 from diffusers import UNetUnconditionalModel
-from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint
+from scripts.convert_ncsnpp_original_checkpoint_to_diffusers import convert_ncsnpp_checkpoint
 from huggingface_hub import hf_hub_download
 import torch
 
-model_id = "fusing/latent-diffusion-celeba-256"
-subfolder = "unet"
-#model_id = "fusing/unet-ldm-dummy"
-#subfolder = None
-
-checkpoint = "diffusion_model.pt"
-config = "config.json"
-
-if subfolder is not None:
-    checkpoint = os.path.join(subfolder, checkpoint)
-    config = os.path.join(subfolder, config)
-
-original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
-config_path = hf_hub_download(model_id, config)
-
-with open(config_path) as f:
-    config = json.load(f)
-
-checkpoint = convert_ldm_checkpoint(original_checkpoint, config)
-
-def current_codebase_conversion():
-    model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
-    model.eval()
-
-    torch.manual_seed(0)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(0)
-
-    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
-    time_step = torch.tensor([10] * noise.shape[0])
-
-    with torch.no_grad():
-        output = model(noise, time_step)
-
-    return model.state_dict()
+def convert_checkpoint(model_id, subfolder=None, checkpoint="diffusion_model.pt", config="config.json"):
+    if subfolder is not None:
+        checkpoint = os.path.join(subfolder, checkpoint)
+        config = os.path.join(subfolder, config)
+
+    original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint), map_location="cpu")
+    config_path = hf_hub_download(model_id, config)
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    checkpoint = convert_ncsnpp_checkpoint(original_checkpoint, config)
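+
+    # `current_codebase_conversion` below loads the same checkpoint through the
+    # current `from_pretrained` code path, so its state dict can be compared key
+    # by key against the output of the standalone conversion script.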
print(f"\t{key}") - else: - print("Keys are the same between the two checkpoints") - all_layers_included = True - - keys = ch_0.keys() - non_equal_keys = [] - - if all_layers_included: - for key in keys: - try: - if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()): - non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}') - - except RuntimeError as e: - print(e) - non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}') - - if len(non_equal_keys): - non_equal_keys = '\n\t'.join(non_equal_keys) - print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}") + print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})") + for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))): + print(f"\t{key}") else: - print("All keys are equal across checkpoints.") + print("Keys are the same between the two checkpoints") + all_layers_included = True + + keys = ch_0.keys() + non_equal_keys = [] + + if all_layers_included: + for key in keys: + try: + if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()): + non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}') + + except RuntimeError as e: + print(e) + non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}') + + if len(non_equal_keys): + non_equal_keys = '\n\t'.join(non_equal_keys) + print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}") + else: + print("All keys are equal across checkpoints.") -diff_between_checkpoints(currently_converted_checkpoint, checkpoint) -torch.save(checkpoint, "/path/to/checkpoint/") + diff_between_checkpoints(currently_converted_checkpoint, checkpoint) + os.makedirs( f"{model_id}_converted",exist_ok =True) + torch.save(checkpoint, f"{model_id}_converted/diffusion_model.pt") + + +model_ids = ["fusing/ffhq_ncsnpp","fusing/church_256-ncsnpp-ve", "fusing/celebahq_256-ncsnpp-ve", + "fusing/bedroom_256-ncsnpp-ve","fusing/ffhq_256-ncsnpp-ve","fusing/ncsnpp-ffhq-ve-dummy" + ] +for model in model_ids: + print(f"converting {model}") + try: + convert_checkpoint(model) + except Exception as e: + print(e) + +from tests.test_modeling_utils import PipelineTesterMixin, NCSNppModelTests + +tester1 = NCSNppModelTests() +tester2 = PipelineTesterMixin() + +os.environ["RUN_SLOW"] = '1' +cmd = "export RUN_SLOW=1; echo $RUN_SLOW" # or whatever command +os.system(cmd) +tester2.test_score_sde_ve_pipeline(f"{model_ids[0]}_converted") +tester1.test_output_pretrained_ve_mid(f"{model_ids[2]}_converted") +tester1.test_output_pretrained_ve_large(f"{model_ids[-1]}_converted") diff --git a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py new file mode 100644 index 00000000..79bdb560 --- /dev/null +++ b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+    new_model_architecture = UNetUnconditionalModel(**config)
+    new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
+    new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
+    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
+    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data
+
+    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
+    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data
+
+    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
+    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data
+
+    new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
+    new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
+    new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data
+    new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data
+
+    module_index = 4
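+
+    # The original model registers every layer in a single flat `nn.ModuleList`
+    # (the `all_modules.*` keys); indices 0-3 (time embedding and input conv) were
+    # handled above, and `module_index` now walks the remaining entries in order.
+    # The `NIN_*` entries are 1x1 (dense) projections whose `W` is stored transposed
+    # relative to `nn.Linear.weight`, hence the `.T` in the attention setter below.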
+
+    def set_attention_weights(new_layer, old_checkpoint, index):
+        new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
+        new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
+        new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T
+
+        new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
+        new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
+        new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data
+
+        new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T
+        new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data
+
+        new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
+        new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data
+
+    def set_resnet_weights(new_layer, old_checkpoint, index):
+        new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
+        new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
+        new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
+        new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data
+
+        new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data
+        new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
+        new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
+        new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data
+
+        new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
+        new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data
+
+        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
+            new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data
+            new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data
+
+    for i, block in enumerate(new_model_architecture.downsample_blocks):
+        has_attentions = hasattr(block, "attentions")
+        for j in range(len(block.resnets)):
+            set_resnet_weights(block.resnets[j], checkpoint, module_index)
+            module_index += 1
+            if has_attentions:
+                set_attention_weights(block.attentions[j], checkpoint, module_index)
+                module_index += 1
+
+        if hasattr(block, "downsamplers") and block.downsamplers is not None:
+            set_resnet_weights(block.resnet_down, checkpoint, module_index)
+            module_index += 1
+            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
+            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
+            module_index += 1
+
+    set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
+    module_index += 1
+    set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
+    module_index += 1
+    set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
+    module_index += 1
+
+    for i, block in enumerate(new_model_architecture.upsample_blocks):
+        has_attentions = hasattr(block, "attentions")
+        for j in range(len(block.resnets)):
+            set_resnet_weights(block.resnets[j], checkpoint, module_index)
+            module_index += 1
+        if has_attentions:
+            # the original up blocks register at most one attention layer per
+            # resolution, after all of the block's resnets
+            set_attention_weights(block.attentions[0], checkpoint, module_index)
+            module_index += 1
+
+        if hasattr(block, "resnet_up") and block.resnet_up is not None:
+            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
+            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
+            module_index += 1
+            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
+            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
+            module_index += 1
+            set_resnet_weights(block.resnet_up, checkpoint, module_index)
+            module_index += 1
+
+    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
+    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
+    module_index += 1
+    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
+    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
+
+    return new_model_architecture.state_dict()
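+
+# Example invocation (the paths are placeholders):
+#   python scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py \
+#       --checkpoint_path diffusion_model.pt \
+#       --config_file config.json \
+#       --dump_path diffusion_model_new.pt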
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--checkpoint_path",
+        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
+        type=str,
+        required=False,
+        help="Path to the checkpoint to convert.",
+    )
+
+    parser.add_argument(
+        "--config_file",
+        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
+        type=str,
+        required=False,
+        help="The config json file corresponding to the architecture.",
+    )
+
+    parser.add_argument(
+        "--dump_path",
+        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
+        type=str,
+        required=False,
+        help="Path to the output model.",
+    )
+
+    args = parser.parse_args()
+
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+
+    with open(args.config_file) as f:
+        config = json.loads(f.read())
+
+    converted_checkpoint = convert_ncsnpp_checkpoint(checkpoint, config)
+    torch.save(converted_checkpoint, args.dump_path)
diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py
index cee54e37..60cac787 100644
--- a/src/diffusers/models/unet_unconditional.py
+++ b/src/diffusers/models/unet_unconditional.py
@@ -152,6 +152,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         progressive_input="input_skip",
         resnet_num_groups=32,
         continuous=True,
+        **kwargs,
     ):
         super().__init__()
         # register all __init__ params to be accessible via `self.config.<...>`
@@ -454,7 +455,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         # 5. up
         skip_sample = None
         for upsample_block in self.upsample_blocks:
-
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]