[Conversion] Make sure ema weights are extracted correctly (#1937)

* [Conversion] Make sure ema weights are extracted correctly

* up

* finish
This commit is contained in:
Patrick von Platen 2023-01-06 10:08:39 +04:00 committed by GitHub
parent 2533f92532
commit 409387889d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 14 deletions

View File

@ -315,26 +315,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model." unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100: if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
print(f"Checkpoint {path} has both EMA and non-EMA weights.") print(f"Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema: print(
print( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." )
) for key in keys:
for key in keys: if key.startswith("model.diffusion_model"):
if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else:
else: if sum(k.startswith("model_ema") for k in keys) > 100:
print( print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag." " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
) )
for key in keys: for key in keys:
if key.startswith(unet_key): if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {} new_checkpoint = {}