CompVis -> diffusers script - allow converting from merged checkpoint to either EMA or non-EMA (#991)
* improve script

* up
commit d9cfe325a5
parent 0343d8f531
@@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
     return config
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config):
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
 
     # extract state_dict for UNet
     unet_state_dict = {}
-    unet_key = "model.diffusion_model."
     keys = list(checkpoint.keys())
+
+    unet_key = "model.diffusion_model."
+    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        if extract_ema:
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
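The one non-obvious step in this hunk is `flat_ema_key`: in a merged CompVis checkpoint, each EMA weight is stored under a flattened name in which every dot after the leading `model.` segment has been removed. A minimal sketch of the mapping, using a hypothetical key that follows the LDM naming scheme:

    key = "model.diffusion_model.input_blocks.0.0.weight"

    # Drop the leading "model" segment and join the rest WITHOUT dots,
    # matching how the EMA copy of each parameter is named.
    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])

    print(flat_ema_key)  # model_ema.diffusion_modelinput_blocks00weight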
@@ -630,6 +649,15 @@ if __name__ == "__main__":
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
     )
+    parser.add_argument(
+        "--extract_ema",
+        action="store_true",
+        help=(
+            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+        ),
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
 
     args = parser.parse_args()
@@ -641,7 +669,9 @@ if __name__ == "__main__":
         args.original_config_file = "./v1-inference.yaml"
 
     original_config = OmegaConf.load(args.original_config_file)
-    checkpoint = torch.load(args.checkpoint_path)["state_dict"]
+
+    checkpoint = torch.load(args.checkpoint_path)
+    checkpoint = checkpoint["state_dict"]
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
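For context, the two-step load reflects how CompVis `.ckpt` files are laid out: they are PyTorch Lightning checkpoints, so the model tensors sit under a `state_dict` key next to training metadata. A minimal sketch, with a hypothetical file name:

    import torch

    # Lightning-style checkpoint: the weights live under "state_dict",
    # alongside metadata such as optimizer state and the global step.
    checkpoint = torch.load("sd-v1-4.ckpt")  # hypothetical file name
    state_dict = checkpoint["state_dict"]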
@@ -669,7 +699,9 @@ if __name__ == "__main__":
 
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
+    )
 
     unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
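A sketch of how the new flag is used end to end (the script and file names are assumed; `--checkpoint_path` and `--original_config_file` are inferred from the `args.*` attributes above):

    python convert_original_stable_diffusion_to_diffusers.py \
        --checkpoint_path ./sd-v1-4.ckpt \
        --dump_path ./sd-v1-4-diffusers \
        --extract_ema

Omit `--extract_ema` to extract the non-EMA weights instead, which are the better starting point for continued fine-tuning.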