Handle missing global_step key in scripts/convert_original_stable_diffusion_to_diffusers.py (#1612)
handle missing global_step key and don't download config if it already exists
This commit is contained in:
parent
ded3299d68
commit
d2dc4de303
|
@ -854,7 +854,13 @@ if __name__ == "__main__":
|
|||
prediction_type = args.prediction_type
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path)
|
||||
global_step = checkpoint["global_step"]
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print("global_step key not found in model")
|
||||
global_step = None
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
upcast_attention = False
|
||||
|
@ -862,20 +868,24 @@ if __name__ == "__main__":
|
|||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
# model_type = "v2"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
if not os.path.isfile("v2-inference-v.yaml"):
|
||||
# model_type = "v2"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
" -O v2-inference-v.yaml"
|
||||
)
|
||||
args.original_config_file = "./v2-inference-v.yaml"
|
||||
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
else:
|
||||
# model_type = "v1"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
if not os.path.isfile("v1-inference.yaml"):
|
||||
# model_type = "v1"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
" -O v1-inference.yaml"
|
||||
)
|
||||
args.original_config_file = "./v1-inference.yaml"
|
||||
|
||||
original_config = OmegaConf.load(args.original_config_file)
|
||||
|
|
Loading…
Reference in New Issue