fix reshaping for picking apps
This commit is contained in:
parent
a05ffca82e
commit
1236329677
|
@ -2,11 +2,35 @@ import os
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import save_file, load_file
|
from safetensors.torch import save_file, load_file
|
||||||
|
|
||||||
def fix_vae_keys(state_dict):
|
reshapes = ["first_stage_model.decoder.mid.attn_1.to_k.weight",
|
||||||
|
"first_stage_model.decoder.mid.attn_1.to_q.weight",
|
||||||
|
"first_stage_model.decoder.mid.attn_1.to_v.weight",
|
||||||
|
"first_stage_model.encoder.mid.attn_1.to_k.weight",
|
||||||
|
"first_stage_model.encoder.mid.attn_1.to_q.weight",
|
||||||
|
"first_stage_model.encoder.mid.attn_1.to_v.weight",
|
||||||
|
"first_stage_model.decoder.mid.attn_1.to_out.0.weight",
|
||||||
|
"first_stage_model.encoder.mid.attn_1.to_out.0.weight"
|
||||||
|
]
|
||||||
|
|
||||||
|
def _reshape(state_dict, key):
|
||||||
|
if key in reshapes:
|
||||||
|
if state_dict[key].dim() == 2:
|
||||||
|
old_shape = state_dict[key].shape
|
||||||
|
# add two dimensions after last dim
|
||||||
|
state_dict[key] = state_dict[key].unsqueeze(-1).unsqueeze(-1)
|
||||||
|
print(f" ** reshaped {key} from {old_shape} to {state_dict[key].shape}")
|
||||||
|
else:
|
||||||
|
print(f" ** skipping {key} because it is already correct shape {state_dict[key].shape}")
|
||||||
|
|
||||||
|
def fix_vae_keys(state_dict, is_sd1=True):
|
||||||
|
if not is_sd1:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
|
|
||||||
with open("backdate_vae_keys.log", "w") as f:
|
with open("backdate_vae_keys.log", "w") as f:
|
||||||
f.write(f"keys:\n")
|
f.write(f"keys:\n")
|
||||||
|
changed_i = 0
|
||||||
|
|
||||||
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
|
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' not in state_dict:
|
||||||
# openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever
|
# openai clip-l for some reason has this defined as part of its state_dict, which is dumb, but whatever
|
||||||
|
@ -14,33 +38,33 @@ def fix_vae_keys(state_dict):
|
||||||
|
|
||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
new_key = key
|
new_key = key
|
||||||
|
_reshape(state_dict, key)
|
||||||
if key.startswith("first_stage_model"):
|
if key.startswith("first_stage_model"):
|
||||||
|
|
||||||
if ".to_q" in key:
|
if ".to_q" in key:
|
||||||
print(f" * backdating {key}")
|
print(f" * backdating {key} {state_dict[key].shape}")
|
||||||
new_key = new_key.replace('.to_q.', '.q.')
|
new_key = new_key.replace('.to_q.', '.q.')
|
||||||
print(f" ** new key -> {new_key}\n")
|
print(f" ** new key -> {new_key}\n")
|
||||||
elif ".to_k" in key:
|
elif ".to_k" in key:
|
||||||
print(f" * backdating {key}")
|
print(f" * backdating {key} {state_dict[key].shape}")
|
||||||
new_key = new_key.replace('.to_k.', '.k.')
|
new_key = new_key.replace('.to_k.', '.k.')
|
||||||
print(f" ** new key -> {new_key}\n")
|
print(f" ** new key -> {new_key}\n")
|
||||||
elif ".to_v" in key:
|
elif ".to_v" in key:
|
||||||
print(f" * backdating {key}")
|
print(f" * backdating {key} {state_dict[key].shape}")
|
||||||
new_key = new_key.replace('.to_v.', '.v.')
|
new_key = new_key.replace('.to_v.', '.v.')
|
||||||
print(f" ** new key -> {new_key}\n")
|
print(f" ** new key -> {new_key}\n")
|
||||||
elif ".to_out.0" in key:
|
elif ".to_out.0" in key:
|
||||||
print(f" * backdating {key}")
|
print(f" * backdating {key} {state_dict[key].shape}")
|
||||||
new_key = new_key.replace('.to_out.0', '.proj_out')
|
new_key = new_key.replace('.to_out.0', '.proj_out')
|
||||||
print(f" ** new key -> {new_key}\n")
|
print(f" ** new key -> {new_key} {state_dict[key].shape}\n")
|
||||||
|
|
||||||
new_state_dict[new_key] = state_dict[key]
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
|
||||||
changed = key != new_key
|
changed = 1 if key != new_key else 0
|
||||||
changed = 1 if changed else 0
|
f.write(f"{changed}: {key} -- {new_key} {new_state_dict[new_key].shape}\n")
|
||||||
f.write(f"{changed}: {key} -- {new_key}\n")
|
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def _backdate_keys(filepath, state_dict):
|
def _backdate_keys(filepath, state_dict):
|
||||||
new_state_dict = fix_vae_keys(state_dict)
|
new_state_dict = fix_vae_keys(state_dict)
|
||||||
base_path_without_ext = os.path.splitext(filepath)[0]
|
base_path_without_ext = os.path.splitext(filepath)[0]
|
||||||
|
@ -66,8 +90,8 @@ def _compare_keys(filea_state_dict, fileb_state_dict):
|
||||||
for filea_key, fileb_key in zip(filea_state_dict_keys, fileb_state_dict_keys):
|
for filea_key, fileb_key in zip(filea_state_dict_keys, fileb_state_dict_keys):
|
||||||
if filea_key != fileb_key:
|
if filea_key != fileb_key:
|
||||||
print("Mismatched keys:")
|
print("Mismatched keys:")
|
||||||
print (f" filea key: {filea_key}")
|
print (f" filea key: {filea_key} {filea_state_dict[filea_key].shape}")
|
||||||
print (f" fileb key: {fileb_key}")
|
print (f" fileb key: {fileb_key} {fileb_state_dict[fileb_key].shape}")
|
||||||
else:
|
else:
|
||||||
#print (f"{ckpt_key} == {st_key}")
|
#print (f"{ckpt_key} == {st_key}")
|
||||||
pass
|
pass
|
||||||
|
@ -82,6 +106,12 @@ def _load(filepath):
|
||||||
state_dict = torch.load(filepath, map_location='cpu')['state_dict']
|
state_dict = torch.load(filepath, map_location='cpu')['state_dict']
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
def _dump_keys(filepath, state_dict):
|
||||||
|
with open(filepath, "w") as f:
|
||||||
|
state_dict_keys = sorted(state_dict.keys())
|
||||||
|
for key in state_dict_keys:
|
||||||
|
f.write(f"{key} - {state_dict[key].shape}\n")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("BACKDATE AutoencoderKL/VAE KEYS TO OLD NAMES SCRIPT OF DOOM")
|
print("BACKDATE AutoencoderKL/VAE KEYS TO OLD NAMES SCRIPT OF DOOM")
|
||||||
print("================================")
|
print("================================")
|
||||||
|
@ -89,9 +119,11 @@ if __name__ == "__main__":
|
||||||
print(" --fileb <path to ckpt or safetensors file> to compare keys to filea")
|
print(" --fileb <path to ckpt or safetensors file> to compare keys to filea")
|
||||||
print(" --compare to run keys comparison (requires both --filea and --fileb)")
|
print(" --compare to run keys comparison (requires both --filea and --fileb)")
|
||||||
print(" --backdate to backdate keys (only for --filea)")
|
print(" --backdate to backdate keys (only for --filea)")
|
||||||
print(" You must specify either --compare or --backdate to do anything")
|
print(" --dumpkeys to write key and shapes for either or both files keys for files to '<filename>.txt'")
|
||||||
|
print(" You must specify one of --compare or --backdate or --dumpkeys to do anything.")
|
||||||
print(" ex. python utils/backdate_vae_keys.py --filea my_finetune.safetensors --fileb original_sd15.ckpt --compare")
|
print(" ex. python utils/backdate_vae_keys.py --filea my_finetune.safetensors --fileb original_sd15.ckpt --compare")
|
||||||
print(" ex. python utils/backdate_vae_keys.py --filea my_finetune.safetensors --backdate")
|
print(" ex. python utils/backdate_vae_keys.py --filea my_finetune.safetensors --backdate")
|
||||||
|
print(" ex. python utils/backdate_vae_keys.py --filea what_is_this_model_shape.safetensors --dumpkeys")
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -99,11 +131,18 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--fileb", type=str, required=False, help="Path to the safetensors file to fix")
|
parser.add_argument("--fileb", type=str, required=False, help="Path to the safetensors file to fix")
|
||||||
parser.add_argument("--compare", action="store_true", help="Compare keys")
|
parser.add_argument("--compare", action="store_true", help="Compare keys")
|
||||||
parser.add_argument("--backdate", action="store_true", help="backdates the keys in filea only")
|
parser.add_argument("--backdate", action="store_true", help="backdates the keys in filea only")
|
||||||
|
parser.add_argument("--dumpkeys", action="store_true", help="dump keys to txt file")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
filea_state_dict = _load(args.filea) if args.filea else None
|
filea_state_dict = _load(args.filea) if args.filea else None
|
||||||
fileb_state_dict = _load(args.fileb) if args.fileb else None
|
fileb_state_dict = _load(args.fileb) if args.fileb else None
|
||||||
|
|
||||||
|
if args.dumpkeys:
|
||||||
|
print(f"Dumping keys to txt files")
|
||||||
|
if args.filea:
|
||||||
|
_dump_keys(f"{os.path.splitext(args.filea)[0]}.txt", filea_state_dict)
|
||||||
|
if args.fileb:
|
||||||
|
_dump_keys(f"{os.path.splitext(args.fileb)[0]}.txt", fileb_state_dict)
|
||||||
if args.compare and not args.backdate:
|
if args.compare and not args.backdate:
|
||||||
print(f"Comparing keys in {args.filea} to {args.fileb}")
|
print(f"Comparing keys in {args.filea} to {args.fileb}")
|
||||||
_compare_keys(filea_state_dict, fileb_state_dict)
|
_compare_keys(filea_state_dict, fileb_state_dict)
|
||||||
|
|
Loading…
Reference in New Issue