From a05ffca82e371fbbbcdcfb067b961d22597e066f Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Wed, 29 Nov 2023 20:25:38 -0500 Subject: [PATCH] add positions id back for clip-l --- utils/backdate_vae_keys.py | 51 +++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/utils/backdate_vae_keys.py b/utils/backdate_vae_keys.py index a23e145..e1ad098 100644 --- a/utils/backdate_vae_keys.py +++ b/utils/backdate_vae_keys.py @@ -5,27 +5,38 @@ from safetensors.torch import save_file, load_file def fix_vae_keys(state_dict): new_state_dict = {} - for key in state_dict.keys(): - new_key = key - if key.startswith("first_stage_model"): - if ".to_q" in key: - print(f" * backdating {key}") - new_key = new_key.replace('.to_q.', '.q.') - print(f" ** new key -> {new_key}\n") - elif ".to_k" in key: - print(f" * backdating {key}") - new_key = new_key.replace('.to_k.', '.k.') - print(f" ** new key -> {new_key}\n") - elif ".to_v" in key: - print(f" * backdating {key}") - new_key = new_key.replace('.to_v.', '.v.') - print(f" ** new key -> {new_key}\n") - elif ".to_out.0" in key: - print(f" * backdating {key}") - new_key = new_key.replace('.to_out.0', '.proj_out') - print(f" ** new key -> {new_key}\n") + with open("backdate_vae_keys.log", "w") as f: + f.write(f"keys:\n") - new_state_dict[new_key] = state_dict[key] + if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict: + # openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever + state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.linspace(0, 76, 77, dtype=torch.int64).unsqueeze(0) + + for key in state_dict.keys(): + new_key = key + if key.startswith("first_stage_model"): + if ".to_q" in key: + print(f" * backdating {key}") + new_key = new_key.replace('.to_q.', '.q.') + print(f" ** new key -> {new_key}\n") + elif ".to_k" in key: + print(f" * backdating {key}") + new_key = new_key.replace('.to_k.', '.k.') + print(f" ** new key -> {new_key}\n") + elif ".to_v" in key: + print(f" * backdating {key}") + new_key = new_key.replace('.to_v.', '.v.') + print(f" ** new key -> {new_key}\n") + elif ".to_out.0" in key: + print(f" * backdating {key}") + new_key = new_key.replace('.to_out.0', '.proj_out') + print(f" ** new key -> {new_key}\n") + + new_state_dict[new_key] = state_dict[key] + + changed = key != new_key + changed = 1 if changed else 0 + f.write(f"{changed}: {key} -- {new_key}\n") return new_state_dict