diff --git a/scripts/change_naming_configs_and_checkpoints.py b/scripts/change_naming_configs_and_checkpoints.py
new file mode 100644
index 00000000..20e1d5c7
--- /dev/null
+++ b/scripts/change_naming_configs_and_checkpoints.py
@@ -0,0 +1,112 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Conversion script for the LDM checkpoints. """
+
+import argparse
+import json
+import os
+
+import torch
+
+from diffusers import UNet2DConditionModel, UNet2DModel
+from transformers.file_utils import has_file
+
+do_only_config = False
+do_only_weights = True
+do_only_renaming = False
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo_path",
+        default=None,
+        type=str,
+        required=True,
+        help="The config json file corresponding to the architecture.",
+    )
+
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+    args = parser.parse_args()
+
+    config_parameters_to_change = {
+        "image_size": "sample_size",
+        "num_res_blocks": "layers_per_block",
+        "block_channels": "block_out_channels",
+        "down_blocks": "down_block_types",
+        "up_blocks": "up_block_types",
+        "downscale_freq_shift": "freq_shift",
+        "resnet_num_groups": "norm_num_groups",
+        "resnet_act_fn": "act_fn",
+        "resnet_eps": "norm_eps",
+        "num_head_channels": "attention_head_dim",
+    }
+
+    key_parameters_to_change = {
+        "time_steps": "time_proj",
+        "mid": "mid_block",
+        "downsample_blocks": "down_blocks",
+        "upsample_blocks": "up_blocks",
+    }
+
+    subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
+
+    with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
+        text = reader.read()
+        config = json.loads(text)
+
+    if do_only_config:
+        for key in config_parameters_to_change.keys():
+            config.pop(key, None)
+
+    if has_file(args.repo_path, "config.json"):
+        model = UNet2DModel(**config)
+    else:
+        class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
+        model = class_name(**config)
+
+    if do_only_config:
+        model.save_config(os.path.join(args.repo_path, subfolder))
+
+    config = dict(model.config)
+
+    if do_only_renaming:
+        for key, value in config_parameters_to_change.items():
+            if key in config:
+                config[value] = config[key]
+                del config[key]
+
+        config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
+        config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
+
+    if do_only_weights:
+        state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
+
+        new_state_dict = {}
+        for param_key, param_value in state_dict.items():
+            if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
+                continue
+            has_changed = False
+            for key, new_key in key_parameters_to_change.items():
+                if not has_changed and param_key.split(".")[0] == key:
+                    new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
+                    has_changed = True
+            if not has_changed:
+                new_state_dict[param_key] = param_value
+
+        model.load_state_dict(new_state_dict)
+        model.save_pretrained(os.path.join(args.repo_path, subfolder))
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 71cb9b73..991a078f 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -48,6 +48,7 @@ class ConfigMixin:
     """
 
     config_name = None
+    ignore_for_config = []
 
     def register_to_config(self, **kwargs):
         if self.config_name is None:
@@ -212,6 +213,9 @@ class ConfigMixin:
         # remove general kwargs if present in dict
         if "kwargs" in expected_keys:
             expected_keys.remove("kwargs")
+        # remove keys to be ignored
+        if len(cls.ignore_for_config) > 0:
+            expected_keys = expected_keys - set(cls.ignore_for_config)
         init_dict = {}
         for key in expected_keys:
             if key in kwargs: