CompVis -> diffusers script - allow converting from merged checkpoint to either EMA or non-EMA (#991)

* improve script

* up
This commit is contained in:
Patrick von Platen 2022-10-26 12:32:07 +02:00 committed by GitHub
parent 0343d8f531
commit d9cfe325a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 36 additions and 4 deletions

View File

@@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
return config return config
def convert_ldm_unet_checkpoint(checkpoint, config): def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
""" """
Takes a state dict and a config, and returns a converted checkpoint. Takes a state dict and a config, and returns a converted checkpoint.
""" """
# extract state_dict for UNet # extract state_dict for UNet
unet_state_dict = {} unet_state_dict = {}
unet_key = "model.diffusion_model."
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys: for key in keys:
if key.startswith(unet_key): if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
@@ -630,6 +649,15 @@ if __name__ == "__main__":
type=str, type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
) )
parser.add_argument(
"--extract_ema",
action="store_true",
help=(
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args() args = parser.parse_args()
@@ -641,7 +669,9 @@ if __name__ == "__main__":
args.original_config_file = "./v1-inference.yaml" args.original_config_file = "./v1-inference.yaml"
original_config = OmegaConf.load(args.original_config_file) original_config = OmegaConf.load(args.original_config_file)
checkpoint = torch.load(args.checkpoint_path)["state_dict"]
checkpoint = torch.load(args.checkpoint_path)
checkpoint = checkpoint["state_dict"]
num_train_timesteps = original_config.model.params.timesteps num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start beta_start = original_config.model.params.linear_start
@@ -669,7 +699,9 @@ if __name__ == "__main__":
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config) unet_config = create_unet_diffusers_config(original_config)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
)
unet = UNet2DConditionModel(**unet_config) unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint)