From d493d16504eef758712aa7a1207fdfbf97bf9742 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Sun, 3 Sep 2023 16:03:45 -0400 Subject: [PATCH] fix conv --- utils/convert_diff_to_ckpt.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/utils/convert_diff_to_ckpt.py b/utils/convert_diff_to_ckpt.py index 5df3d3f..9f00fe4 100644 --- a/utils/convert_diff_to_ckpt.py +++ b/utils/convert_diff_to_ckpt.py @@ -23,6 +23,7 @@ import os.path as osp import re +from safetensors import safe_open import torch @@ -288,17 +289,38 @@ def convert(model_path: str, checkpoint_path: str, half: bool): text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") # Convert the UNet model - unet_state_dict = torch.load(unet_path, map_location="cpu") + if osp.exists(unet_path): + unet_state_dict = torch.load(unet_path, map_location="cpu") + else: + unet_state_dict = {} + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + with safe_open(unet_path, framework="pt", device="cpu") as f: + for key in f.keys(): + unet_state_dict[key] = f.get_tensor(key) unet_state_dict = convert_unet_state_dict(unet_state_dict) unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} # Convert the VAE model - vae_state_dict = torch.load(vae_path, map_location="cpu") + if osp.exists(vae_path): + vae_state_dict = torch.load(vae_path, map_location="cpu") + else: + vae_state_dict = {} + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + with safe_open(vae_path, framework="pt", device="cpu") as f: + for key in f.keys(): + vae_state_dict[key] = f.get_tensor(key) vae_state_dict = convert_vae_state_dict(vae_state_dict) vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} # Convert the text encoder model - text_enc_dict = torch.load(text_enc_path, map_location="cpu") + if osp.exists(text_enc_path): + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + else: + text_enc_dict = {} + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + with safe_open(text_enc_path, framework="pt", device="cpu") as f: + for key in f.keys(): + text_enc_dict[key] = f.get_tensor(key) # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict