save as safetensors not ckpt
This commit is contained in:
parent
6ea721887c
commit
299af88f22
|
@ -24,6 +24,7 @@
|
|||
import os.path as osp
|
||||
import re
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -338,6 +339,4 @@ def convert(model_path: str, checkpoint_path: str, half: bool):
|
|||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
if half:
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||
state_dict = {"state_dict": state_dict}
|
||||
torch.save(state_dict, checkpoint_path)
|
||||
|
||||
save_file(state_dict, checkpoint_path.replace(".ckpt", ".safetensors"))
|
||||
|
|
Loading…
Reference in New Issue