Handle missing global_step key in scripts/convert_original_stable_diffusion_to_diffusers.py (#1612)

handle missing global_step key and don't download config if it already exists
This commit is contained in:
Cyberes 2022-12-12 08:10:52 -07:00 committed by GitHub
parent ded3299d68
commit d2dc4de303
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 9 deletions

View File

@ -854,7 +854,13 @@ if __name__ == "__main__":
prediction_type = args.prediction_type prediction_type = args.prediction_type
checkpoint = torch.load(args.checkpoint_path) checkpoint = torch.load(args.checkpoint_path)
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
global_step = checkpoint["global_step"] global_step = checkpoint["global_step"]
else:
print("global_step key not found in model")
global_step = None
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
upcast_attention = False upcast_attention = False
@ -862,9 +868,11 @@ if __name__ == "__main__":
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
if not os.path.isfile("v2-inference-v.yaml"):
# model_type = "v2" # model_type = "v2"
os.system( os.system(
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
" -O v2-inference-v.yaml"
) )
args.original_config_file = "./v2-inference-v.yaml" args.original_config_file = "./v2-inference-v.yaml"
@ -872,9 +880,11 @@ if __name__ == "__main__":
# v2.1 needs to upcast attention # v2.1 needs to upcast attention
upcast_attention = True upcast_attention = True
else: else:
if not os.path.isfile("v1-inference.yaml"):
# model_type = "v1" # model_type = "v1"
os.system( os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
" -O v1-inference.yaml"
) )
args.original_config_file = "./v1-inference.yaml" args.original_config_file = "./v1-inference.yaml"