From d9cfe325a53502641f16ce4f839391c5b0d0a684 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 26 Oct 2022 12:32:07 +0200
Subject: [PATCH] CompVis -> diffusers script - allow converting from merged
 checkpoint to either EMA or non-EMA (#991)

* improve script

* up
---
 ..._original_stable_diffusion_to_diffusers.py | 40 +++++++++++++++++--
 1 file changed, 36 insertions(+), 4 deletions(-)

diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index db1b3073..46073001 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
     return config
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config):
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
 
     # extract state_dict for UNet
     unet_state_dict = {}
-    unet_key = "model.diffusion_model."
     keys = list(checkpoint.keys())
+
+    unet_key = "model.diffusion_model."
+    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        if extract_ema:
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
@@ -630,6 +649,15 @@ if __name__ == "__main__":
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
     )
+    parser.add_argument(
+        "--extract_ema",
+        action="store_true",
+        help=(
+            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+        ),
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
@@ -641,7 +669,9 @@ if __name__ == "__main__":
         args.original_config_file = "./v1-inference.yaml"
 
     original_config = OmegaConf.load(args.original_config_file)
-    checkpoint = torch.load(args.checkpoint_path)["state_dict"]
+
+    checkpoint = torch.load(args.checkpoint_path)
+    checkpoint = checkpoint["state_dict"]
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
@@ -669,7 +699,9 @@
 
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
+    )
 
     unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
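
Note (not part of the patch above): the snippet below is a standalone sketch of the key-mapping
trick the patch relies on. In merged CompVis checkpoints, each EMA weight sits under a single
flattened key: "model_ema." followed by the original key with its leading "model." segment
dropped and the remaining dots removed, which is what
`flat_ema_key = "model_ema." + "".join(key.split(".")[1:])` reconstructs. The checkpoint
filename is a placeholder.

import torch

# Load a merged CompVis-style checkpoint (hypothetical filename).
checkpoint = torch.load("sd-full-ema.ckpt")["state_dict"]
keys = list(checkpoint.keys())

# Heuristic from the patch: more than 100 `model_ema.*` entries means the
# checkpoint carries both EMA and non-EMA weights.
has_ema = sum(k.startswith("model_ema") for k in keys) > 100

unet_key = "model.diffusion_model."
unet_state_dict = {}
for key in keys:
    if key.startswith(unet_key):
        if has_ema:
            # e.g. "model.diffusion_model.input_blocks.0.0.weight"
            #  ->  "model_ema.diffusion_modelinput_blocks00weight"
            flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
            unet_state_dict[key.replace(unet_key, "")] = checkpoint[flat_ema_key]
        else:
            unet_state_dict[key.replace(unet_key, "")] = checkpoint[key]

With the patch applied, a conversion that keeps the EMA weights would be invoked along these
lines (paths are placeholders):

python scripts/convert_original_stable_diffusion_to_diffusers.py \
    --checkpoint_path ./sd-full-ema.ckpt \
    --original_config_file ./v1-inference.yaml \
    --dump_path ./converted-model \
    --extract_ema

Dropping `--extract_ema` keeps the non-EMA weights instead, which the new help text recommends
when the goal is to continue fine-tuning rather than run inference.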