add positions id back for clip-l
This commit is contained in:
parent
49e382c82b
commit
a05ffca82e
|
@ -5,27 +5,38 @@ from safetensors.torch import save_file, load_file
|
||||||
def fix_vae_keys(state_dict):
|
def fix_vae_keys(state_dict):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
|
|
||||||
for key in state_dict.keys():
|
with open("backdate_vae_keys.log", "w") as f:
|
||||||
new_key = key
|
f.write(f"keys:\n")
|
||||||
if key.startswith("first_stage_model"):
|
|
||||||
if ".to_q" in key:
|
|
||||||
print(f" * backdating {key}")
|
|
||||||
new_key = new_key.replace('.to_q.', '.q.')
|
|
||||||
print(f" ** new key -> {new_key}\n")
|
|
||||||
elif ".to_k" in key:
|
|
||||||
print(f" * backdating {key}")
|
|
||||||
new_key = new_key.replace('.to_k.', '.k.')
|
|
||||||
print(f" ** new key -> {new_key}\n")
|
|
||||||
elif ".to_v" in key:
|
|
||||||
print(f" * backdating {key}")
|
|
||||||
new_key = new_key.replace('.to_v.', '.v.')
|
|
||||||
print(f" ** new key -> {new_key}\n")
|
|
||||||
elif ".to_out.0" in key:
|
|
||||||
print(f" * backdating {key}")
|
|
||||||
new_key = new_key.replace('.to_out.0', '.proj_out')
|
|
||||||
print(f" ** new key -> {new_key}\n")
|
|
||||||
|
|
||||||
new_state_dict[new_key] = state_dict[key]
|
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
|
||||||
|
# openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever
|
||||||
|
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.linspace(0, 76, 77, dtype=torch.int64).unsqueeze(0)
|
||||||
|
|
||||||
|
for key in state_dict.keys():
|
||||||
|
new_key = key
|
||||||
|
if key.startswith("first_stage_model"):
|
||||||
|
if ".to_q" in key:
|
||||||
|
print(f" * backdating {key}")
|
||||||
|
new_key = new_key.replace('.to_q.', '.q.')
|
||||||
|
print(f" ** new key -> {new_key}\n")
|
||||||
|
elif ".to_k" in key:
|
||||||
|
print(f" * backdating {key}")
|
||||||
|
new_key = new_key.replace('.to_k.', '.k.')
|
||||||
|
print(f" ** new key -> {new_key}\n")
|
||||||
|
elif ".to_v" in key:
|
||||||
|
print(f" * backdating {key}")
|
||||||
|
new_key = new_key.replace('.to_v.', '.v.')
|
||||||
|
print(f" ** new key -> {new_key}\n")
|
||||||
|
elif ".to_out.0" in key:
|
||||||
|
print(f" * backdating {key}")
|
||||||
|
new_key = new_key.replace('.to_out.0', '.proj_out')
|
||||||
|
print(f" ** new key -> {new_key}\n")
|
||||||
|
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
|
||||||
|
changed = key != new_key
|
||||||
|
changed = 1 if changed else 0
|
||||||
|
f.write(f"{changed}: {key} -- {new_key}\n")
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue