CompVis -> diffusers script - allow converting from merged checkpoint to either EMA or non-EMA (#991)
* improve script
* up
This commit is contained in:
parent
0343d8f531
commit
d9cfe325a5
|
@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
|
|||
return config
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(checkpoint, config):
|
||||
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
"""
|
||||
|
||||
# extract state_dict for UNet
|
||||
unet_state_dict = {}
|
||||
unet_key = "model.diffusion_model."
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least 100 parameters have to start with `model_ema` for the checkpoint to be treated as containing EMA weights
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(
|
||||
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||
)
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||
else:
|
||||
print(
|
||||
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
if key.startswith(unet_key):
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||
|
@ -630,6 +649,15 @@ if __name__ == "__main__":
|
|||
type=str,
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extract_ema",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
||||
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
||||
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -641,7 +669,9 @@ if __name__ == "__main__":
|
|||
args.original_config_file = "./v1-inference.yaml"
|
||||
|
||||
original_config = OmegaConf.load(args.original_config_file)
|
||||
checkpoint = torch.load(args.checkpoint_path)["state_dict"]
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path)
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
|
@ -669,7 +699,9 @@ if __name__ == "__main__":
|
|||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config)
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
|
Loading…
Reference in New Issue