diff --git a/utils/backdate_vae_keys.py b/utils/backdate_vae_keys.py
index e1ad098..1912f1b 100644
--- a/utils/backdate_vae_keys.py
+++ b/utils/backdate_vae_keys.py
@@ -2,11 +2,35 @@
 import os
 import torch
 from safetensors.torch import save_file, load_file
 
-def fix_vae_keys(state_dict):
+reshapes = ["first_stage_model.decoder.mid.attn_1.to_k.weight",
+            "first_stage_model.decoder.mid.attn_1.to_q.weight",
+            "first_stage_model.decoder.mid.attn_1.to_v.weight",
+            "first_stage_model.encoder.mid.attn_1.to_k.weight",
+            "first_stage_model.encoder.mid.attn_1.to_q.weight",
+            "first_stage_model.encoder.mid.attn_1.to_v.weight",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.weight",
+            "first_stage_model.encoder.mid.attn_1.to_out.0.weight"
+            ]
+
+def _reshape(state_dict, key):
+    if key in reshapes:
+        if state_dict[key].dim() == 2:
+            old_shape = state_dict[key].shape
+            # add two trailing singleton dimensions: Linear [C, C] -> 1x1 Conv2d [C, C, 1, 1]
+            state_dict[key] = state_dict[key].unsqueeze(-1).unsqueeze(-1)
+            print(f" ** reshaped {key} from {old_shape} to {state_dict[key].shape}")
+        else:
+            print(f" ** skipping {key} because it is already correct shape {state_dict[key].shape}")
+
+def fix_vae_keys(state_dict, is_sd1=True):
+    if not is_sd1:
+        return state_dict
+
     new_state_dict = {}
     with open("backdate_vae_keys.log", "w") as f:
         f.write(f"keys:\n")
+        changed_i = 0
         if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
             # openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever
@@ -14,33 +38,33 @@ def fix_vae_keys(state_dict):
         for key in state_dict.keys():
             new_key = key
+            _reshape(state_dict, key)
             if key.startswith("first_stage_model"):
+
                 if ".to_q" in key:
-                    print(f" * backdating {key}")
+                    print(f" * backdating {key} {state_dict[key].shape}")
                     new_key = new_key.replace('.to_q.', '.q.')
                     print(f" ** new key -> {new_key}\n")
                 elif ".to_k" in key:
-                    print(f" * backdating {key}")
+                    print(f" * backdating {key} {state_dict[key].shape}")
                     new_key = new_key.replace('.to_k.', '.k.')
                     print(f" ** new key -> {new_key}\n")
                 elif ".to_v" in key:
-                    print(f" * backdating {key}")
+                    print(f" * backdating {key} {state_dict[key].shape}")
                     new_key = new_key.replace('.to_v.', '.v.')
                     print(f" ** new key -> {new_key}\n")
                 elif ".to_out.0" in key:
-                    print(f" * backdating {key}")
+                    print(f" * backdating {key} {state_dict[key].shape}")
                     new_key = new_key.replace('.to_out.0', '.proj_out')
-                    print(f" ** new key -> {new_key}\n")
+                    print(f" ** new key -> {new_key} {state_dict[key].shape}\n")
 
             new_state_dict[new_key] = state_dict[key]
-            changed = key != new_key
-            changed = 1 if changed else 0
-            f.write(f"{changed}: {key} -- {new_key}\n")
+            changed = 1 if key != new_key else 0
+            f.write(f"{changed}: {key} -- {new_key} {new_state_dict[new_key].shape}\n")
 
     return new_state_dict
 
-
 def _backdate_keys(filepath, state_dict):
     new_state_dict = fix_vae_keys(state_dict)
     base_path_without_ext = os.path.splitext(filepath)[0]
@@ -66,8 +90,8 @@ def _compare_keys(filea_state_dict, fileb_state_dict):
     for filea_key, fileb_key in zip(filea_state_dict_keys, fileb_state_dict_keys):
         if filea_key != fileb_key:
             print("Mismatched keys:")
-            print (f" filea key: {filea_key}")
-            print (f" fileb key: {fileb_key}")
+            print (f" filea key: {filea_key} {filea_state_dict[filea_key].shape}")
+            print (f" fileb key: {fileb_key} {fileb_state_dict[fileb_key].shape}")
         else:
             #print (f"{ckpt_key} == {st_key}")
             pass
@@ -82,6 +106,12 @@ def _load(filepath):
         state_dict = torch.load(filepath, map_location='cpu')['state_dict']
     return state_dict
 
+def _dump_keys(filepath, state_dict):
+    with open(filepath, "w") as f:
+        state_dict_keys = sorted(state_dict.keys())
+        for key in state_dict_keys:
+            f.write(f"{key} - {state_dict[key].shape}\n")
+
 if __name__ == "__main__":
     print("BACKDATE AutoencoderKL/VAE KEYS TO OLD NAMES SCRIPT OF DOOM")
     print("================================")
@@ -89,9 +119,11 @@ if __name__ == "__main__":
     print(" --fileb to compare keys to filea")
     print(" --compare to run keys comparison (requires both --filea and --fileb)")
     print(" --backdate to backdate keys (only for --filea)")
-    print(" You must specify either --compare or --backdate to do anything")
+    print(" --dumpkeys to write the keys and shapes of either or both files to '<filename>.txt'")
+    print(" You must specify one of --compare, --backdate, or --dumpkeys to do anything.")
     print(" ex. python utils/backdate_vae_keys.py --filea my_finetune.safetensors --fileb original_sd15.ckpt --compare")
     print(" ex. python utils/backdate_vae_keys.py --filea my_finetune.safetensors --backdate")
+    print(" ex. python utils/backdate_vae_keys.py --filea what_is_this_model_shape.safetensors --dumpkeys")
 
     import argparse
     parser = argparse.ArgumentParser()
@@ -99,11 +131,18 @@ if __name__ == "__main__":
     parser.add_argument("--fileb", type=str, required=False, help="Path to the safetensors file to fix")
     parser.add_argument("--compare", action="store_true", help="Compare keys")
    parser.add_argument("--backdate", action="store_true", help="backdates the keys in filea only")
+    parser.add_argument("--dumpkeys", action="store_true", help="dump keys and shapes to a txt file")
     args = parser.parse_args()
 
     filea_state_dict = _load(args.filea) if args.filea else None
     fileb_state_dict = _load(args.fileb) if args.fileb else None
 
+    if args.dumpkeys:
+        print(f"Dumping keys to txt files")
+        if args.filea:
+            _dump_keys(f"{os.path.splitext(args.filea)[0]}.txt", filea_state_dict)
+        if args.fileb:
+            _dump_keys(f"{os.path.splitext(args.fileb)[0]}.txt", fileb_state_dict)
     if args.compare and not args.backdate:
         print(f"Comparing keys in {args.filea} to {args.fileb}")
         _compare_keys(filea_state_dict, fileb_state_dict)
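
Note for reviewers: the new _reshape step exists because the old ldm-style AutoencoderKL implements the mid-block attention projections as 1x1 Conv2d layers (weights shaped [C, C, 1, 1]), while newer diffusers-style checkpoints store them as Linear layers (weights shaped [C, C]). Below is a minimal standalone sketch of the round-trip this patch performs on one such weight; the random tensor and the 512-channel width are illustrative stand-ins for a real checkpoint, not code from this repo.

    import torch

    # hypothetical stand-in for a loaded checkpoint: one diffusers-style
    # Linear projection weight stored under its new-style key
    key = "first_stage_model.decoder.mid.attn_1.to_q.weight"
    state_dict = {key: torch.randn(512, 512)}

    # reshape Linear [C, C] -> 1x1 Conv2d [C, C, 1, 1], as _reshape() does
    state_dict[key] = state_dict[key].unsqueeze(-1).unsqueeze(-1)

    # rename the key back to the old ldm name, as fix_vae_keys() does
    old_key = key.replace('.to_q.', '.q.')
    state_dict[old_key] = state_dict.pop(key)

    assert state_dict["first_stage_model.decoder.mid.attn_1.q.weight"].shape == (512, 512, 1, 1)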