add positions id back for clip-l
This commit is contained in:
parent
49e382c82b
commit
a05ffca82e
|
@ -5,27 +5,38 @@ from safetensors.torch import save_file, load_file
|
|||
def fix_vae_keys(state_dict):
|
||||
new_state_dict = {}
|
||||
|
||||
for key in state_dict.keys():
|
||||
new_key = key
|
||||
if key.startswith("first_stage_model"):
|
||||
if ".to_q" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_q.', '.q.')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
elif ".to_k" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_k.', '.k.')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
elif ".to_v" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_v.', '.v.')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
elif ".to_out.0" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_out.0', '.proj_out')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
with open("backdate_vae_keys.log", "w") as f:
|
||||
f.write(f"keys:\n")
|
||||
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
|
||||
# openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.linspace(0, 76, 77, dtype=torch.int64).unsqueeze(0)
|
||||
|
||||
for key in state_dict.keys():
|
||||
new_key = key
|
||||
if key.startswith("first_stage_model"):
|
||||
if ".to_q" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_q.', '.q.')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
elif ".to_k" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_k.', '.k.')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
elif ".to_v" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_v.', '.v.')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
elif ".to_out.0" in key:
|
||||
print(f" * backdating {key}")
|
||||
new_key = new_key.replace('.to_out.0', '.proj_out')
|
||||
print(f" ** new key -> {new_key}\n")
|
||||
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
|
||||
changed = key != new_key
|
||||
changed = 1 if changed else 0
|
||||
f.write(f"{changed}: {key} -- {new_key}\n")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
|
Loading…
Reference in New Issue