add position_ids back for clip-l

Victor Hall 2023-11-29 20:25:38 -05:00
parent 49e382c82b
commit a05ffca82e
1 changed file with 31 additions and 20 deletions

@@ -5,6 +5,13 @@ from safetensors.torch import save_file, load_file
def fix_vae_keys(state_dict):
    new_state_dict = {}
    with open("backdate_vae_keys.log", "w") as f:
        f.write("keys:\n")
        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
            # OpenAI CLIP-L defines position_ids as part of its state_dict, so restore it here if it is missing
            state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.linspace(0, 76, 77, dtype=torch.int64).unsqueeze(0)
        for key in state_dict.keys():
            new_key = key
            if key.startswith("first_stage_model"):
@@ -27,6 +34,10 @@ def fix_vae_keys(state_dict):
            new_state_dict[new_key] = state_dict[key]
            changed = 1 if key != new_key else 0  # flag whether this key was remapped
            f.write(f"{changed}: {key} -- {new_key}\n")
    return new_state_dict
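
For reference, a minimal usage sketch of the helper above (file names are hypothetical; assumes torch is imported in the script alongside the safetensors helpers it already uses):

import torch
from safetensors.torch import load_file, save_file

sd = load_file("model.safetensors")           # load checkpoint tensors from disk
fixed = fix_vae_keys(sd)                      # remap keys; restores position_ids when absent
save_file(fixed, "model_fixed.safetensors")   # write the backdated checkpoint

The restored tensor is a [1, 77] row of token positions 0..76, so torch.arange(77, dtype=torch.int64).unsqueeze(0) would produce the same values as the linspace call in the diff.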