diff --git a/modules/sd_models.py b/modules/sd_models.py index ae36841af..772364808 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -5,6 +5,7 @@ import gc from collections import namedtuple import torch import re +import safetensors.torch from omegaconf import OmegaConf from ldm.util import instantiate_from_config @@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - if checkpoint_file.endswith(".safetensors"): - try: - from safetensors.torch import load_file - except ImportError as e: - raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") - pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location) else: pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") diff --git a/requirements.txt b/requirements.txt index e4e5ec642..5f3d96232 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ lark inflection GitPython torchsde +safetensors diff --git a/requirements_versions.txt b/requirements_versions.txt index 8d557fe38..035fa82f4 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -26,3 +26,4 @@ lark==1.1.2 inflection==0.5.1 GitPython==3.1.27 torchsde==0.2.5 +safetensors==0.2.5