fix conv
This commit is contained in:
parent
d6ef90bafc
commit
d493d16504
|
@ -23,6 +23,7 @@
|
|||
|
||||
import os.path as osp
|
||||
import re
|
||||
from safetensors import safe_open
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -288,17 +289,38 @@ def convert(model_path: str, checkpoint_path: str, half: bool):
|
|||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
else:
|
||||
unet_state_dict = {}
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
with safe_open(unet_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
unet_state_dict[key] = f.get_tensor(key)
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
else:
|
||||
vae_state_dict = {}
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
with safe_open(vae_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
vae_state_dict[key] = f.get_tensor(key)
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Convert the text encoder model
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
else:
|
||||
text_enc_dict = {}
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||
with safe_open(text_enc_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
text_enc_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
|
Loading…
Reference in New Issue