diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index ffcea833..5ca98469 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -8,7 +8,7 @@ import re import torch -from safetensors.torch import save_file +from safetensors.torch import load_file, save_file # =================# @@ -278,23 +278,38 @@ if __name__ == "__main__": assert args.checkpoint_path is not None, "Must provide a checkpoint path!" - unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") - vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") - text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + # Path for safetensors + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") # Convert the UNet model - unet_state_dict = torch.load(unet_path, map_location="cpu") unet_state_dict = convert_unet_state_dict(unet_state_dict) unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} # Convert the VAE model - vae_state_dict = torch.load(vae_path, map_location="cpu") vae_state_dict = convert_vae_state_dict(vae_state_dict) vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - # Convert the text encoder model - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict