diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 9ccfc254..675ba4a9 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -315,26 +315,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False unet_key = "model.diffusion_model." # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100: + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: print(f"Checkpoint {path} has both EMA and non-EMA weights.") - if extract_ema: - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: print( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {}