diff --git a/utils/convert_diff_to_ckpt.py b/utils/convert_diff_to_ckpt.py index 9f00fe4..5ef1ccb 100644 --- a/utils/convert_diff_to_ckpt.py +++ b/utils/convert_diff_to_ckpt.py @@ -24,6 +24,7 @@ import os.path as osp import re from safetensors import safe_open +from safetensors.torch import save_file import torch @@ -338,6 +339,4 @@ def convert(model_path: str, checkpoint_path: str, half: bool): state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} if half: state_dict = {k: v.half() for k, v in state_dict.items()} - state_dict = {"state_dict": state_dict} - torch.save(state_dict, checkpoint_path) - \ No newline at end of file + save_file(state_dict, checkpoint_path.replace(".ckpt", ".safetensors"))