Handle missing global_step key in scripts/convert_original_stable_diffusion_to_diffusers.py (#1612)
handle missing global_step key and don't download config if it already exists
This commit is contained in:
parent
ded3299d68
commit
d2dc4de303
|
@ -854,7 +854,13 @@ if __name__ == "__main__":
|
||||||
prediction_type = args.prediction_type
|
prediction_type = args.prediction_type
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint_path)
|
checkpoint = torch.load(args.checkpoint_path)
|
||||||
global_step = checkpoint["global_step"]
|
|
||||||
|
# Sometimes models don't have the global_step item
|
||||||
|
if "global_step" in checkpoint:
|
||||||
|
global_step = checkpoint["global_step"]
|
||||||
|
else:
|
||||||
|
print("global_step key not found in model")
|
||||||
|
global_step = None
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
upcast_attention = False
|
upcast_attention = False
|
||||||
|
@ -862,20 +868,24 @@ if __name__ == "__main__":
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
|
|
||||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||||
# model_type = "v2"
|
if not os.path.isfile("v2-inference-v.yaml"):
|
||||||
os.system(
|
# model_type = "v2"
|
||||||
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
os.system(
|
||||||
)
|
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||||
|
" -O v2-inference-v.yaml"
|
||||||
|
)
|
||||||
args.original_config_file = "./v2-inference-v.yaml"
|
args.original_config_file = "./v2-inference-v.yaml"
|
||||||
|
|
||||||
if global_step == 110000:
|
if global_step == 110000:
|
||||||
# v2.1 needs to upcast attention
|
# v2.1 needs to upcast attention
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
else:
|
else:
|
||||||
# model_type = "v1"
|
if not os.path.isfile("v1-inference.yaml"):
|
||||||
os.system(
|
# model_type = "v1"
|
||||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
os.system(
|
||||||
)
|
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||||
|
" -O v1-inference.yaml"
|
||||||
|
)
|
||||||
args.original_config_file = "./v1-inference.yaml"
|
args.original_config_file = "./v1-inference.yaml"
|
||||||
|
|
||||||
original_config = OmegaConf.load(args.original_config_file)
|
original_config = OmegaConf.load(args.original_config_file)
|
||||||
|
|
Loading…
Reference in New Issue