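Fix conversion: fall back to .safetensors weights when the .bin checkpoints (UNet, VAE, text encoder) are missing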
This commit is contained in:
parent
d6ef90bafc
commit
d493d16504
|
@ -23,6 +23,7 @@
|
||||||
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import re
|
import re
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -288,17 +289,38 @@ def convert(model_path: str, checkpoint_path: str, half: bool):
|
||||||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||||
|
|
||||||
# Convert the UNet model
|
# Convert the UNet model
|
||||||
|
if osp.exists(unet_path):
|
||||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||||
|
else:
|
||||||
|
unet_state_dict = {}
|
||||||
|
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||||
|
with safe_open(unet_path, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
unet_state_dict[key] = f.get_tensor(key)
|
||||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||||
|
|
||||||
# Convert the VAE model
|
# Convert the VAE model
|
||||||
|
if osp.exists(vae_path):
|
||||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||||
|
else:
|
||||||
|
vae_state_dict = {}
|
||||||
|
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||||
|
with safe_open(vae_path, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
vae_state_dict[key] = f.get_tensor(key)
|
||||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||||
|
|
||||||
# Convert the text encoder model
|
# Convert the text encoder model
|
||||||
|
if osp.exists(text_enc_path):
|
||||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||||
|
else:
|
||||||
|
text_enc_dict = {}
|
||||||
|
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||||
|
with safe_open(text_enc_path, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
text_enc_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||||
|
|
Loading…
Reference in New Issue