add positions id back for clip-l

This commit is contained in:
Victor Hall 2023-11-29 20:25:38 -05:00
parent 49e382c82b
commit a05ffca82e
1 changed files with 31 additions and 20 deletions

View File

@ -5,27 +5,38 @@ from safetensors.torch import save_file, load_file
def fix_vae_keys(state_dict):
new_state_dict = {}
for key in state_dict.keys():
new_key = key
if key.startswith("first_stage_model"):
if ".to_q" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_q.', '.q.')
print(f" ** new key -> {new_key}\n")
elif ".to_k" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_k.', '.k.')
print(f" ** new key -> {new_key}\n")
elif ".to_v" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_v.', '.v.')
print(f" ** new key -> {new_key}\n")
elif ".to_out.0" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_out.0', '.proj_out')
print(f" ** new key -> {new_key}\n")
with open("backdate_vae_keys.log", "w") as f:
f.write(f"keys:\n")
new_state_dict[new_key] = state_dict[key]
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
# openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.linspace(0, 76, 77, dtype=torch.int64).unsqueeze(0)
for key in state_dict.keys():
new_key = key
if key.startswith("first_stage_model"):
if ".to_q" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_q.', '.q.')
print(f" ** new key -> {new_key}\n")
elif ".to_k" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_k.', '.k.')
print(f" ** new key -> {new_key}\n")
elif ".to_v" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_v.', '.v.')
print(f" ** new key -> {new_key}\n")
elif ".to_out.0" in key:
print(f" * backdating {key}")
new_key = new_key.replace('.to_out.0', '.proj_out')
print(f" ** new key -> {new_key}\n")
new_state_dict[new_key] = state_dict[key]
changed = key != new_key
changed = 1 if changed else 0
f.write(f"{changed}: {key} -- {new_key}\n")
return new_state_dict