This commit is contained in:
Victor Hall 2023-09-03 16:03:45 -04:00
parent d6ef90bafc
commit d493d16504
1 changed files with 25 additions and 3 deletions

View File

@ -23,6 +23,7 @@
import os.path as osp
import re
from safetensors import safe_open
import torch
@ -288,17 +289,38 @@ def convert(model_path: str, checkpoint_path: str, half: bool):
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
# Convert the UNet model
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(unet_path):
unet_state_dict = torch.load(unet_path, map_location="cpu")
else:
unet_state_dict = {}
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
with safe_open(unet_path, framework="pt", device="cpu") as f:
for key in f.keys():
unet_state_dict[key] = f.get_tensor(key)
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = torch.load(vae_path, map_location="cpu")
else:
vae_state_dict = {}
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
with safe_open(vae_path, framework="pt", device="cpu") as f:
for key in f.keys():
vae_state_dict[key] = f.get_tensor(key)
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Convert the text encoder model
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
else:
text_enc_dict = {}
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
with safe_open(text_enc_path, framework="pt", device="cpu") as f:
for key in f.keys():
text_enc_dict[key] = f.get_tensor(key)
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict