[Conversion] Make sure ema weights are extracted correctly (#1937)
* [Conversion] Make sure ema weights are extracted correctly * up * finish
This commit is contained in:
parent
2533f92532
commit
409387889d
|
@ -315,26 +315,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||||
|
|
||||||
unet_key = "model.diffusion_model."
|
unet_key = "model.diffusion_model."
|
||||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
||||||
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||||
if extract_ema:
|
print(
|
||||||
print(
|
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||||
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||||
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
)
|
||||||
)
|
for key in keys:
|
||||||
for key in keys:
|
if key.startswith("model.diffusion_model"):
|
||||||
if key.startswith("model.diffusion_model"):
|
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
else:
|
||||||
else:
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
print(
|
print(
|
||||||
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||||
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||||
)
|
)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith(unet_key):
|
if key.startswith(unet_key):
|
||||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||||
|
|
||||||
new_checkpoint = {}
|
new_checkpoint = {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue