diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b2e91e3f..eaf6e594 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -19,7 +19,7 @@ import torch from .models.cross_attention import LoRACrossAttnProcessor from .models.modeling_utils import _get_model_file -from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging +from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging if is_safetensors_available(): @@ -150,13 +150,14 @@ class UNet2DConditionLoadersMixin: model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): - if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"): - if weight_name is None: - weight_name = LORA_WEIGHT_NAME_SAFE + # Let's first try to load .safetensors weights + if (is_safetensors_available() and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): try: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, - weights_name=weight_name, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, @@ -169,14 +170,13 @@ class UNet2DConditionLoadersMixin: ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except EnvironmentError: - if weight_name == LORA_WEIGHT_NAME_SAFE: - weight_name = None + # try loading non-safetensors weights + pass + if model_file is None: - if weight_name is None: - weight_name = LORA_WEIGHT_NAME model_file = _get_model_file( pretrained_model_name_or_path_or_dict, - weights_name=weight_name, + weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, @@ -225,9 +225,10 @@ class UNet2DConditionLoadersMixin: self, save_directory: Union[str, os.PathLike], is_main_process: bool = True, - weights_name: str = None, + weight_name: str = None, save_function: Callable = None, safe_serialization: bool = False, + **kwargs, ): r""" Save an attention processor to a directory, so that it can be re-loaded using the @@ -245,6 +246,12 @@ class UNet2DConditionLoadersMixin: need to replace `torch.save` by another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. """ + weight_name = weight_name or deprecate( + "weights_name", + "0.18.0", + "`weights_name` is deprecated, please use `weight_name` instead.", + take_from=kwargs, + ) if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return @@ -265,22 +272,13 @@ class UNet2DConditionLoadersMixin: # Save the model state_dict = model_to_save.state_dict() - # Clean the folder from a previous save - for filename in os.listdir(save_directory): - full_filename = os.path.join(save_directory, filename) - # If we have a shard file that is not going to be replaced, we delete it, but only from the main process - # in distributed settings to avoid race conditions. - weights_no_suffix = weights_name.replace(".bin", "") - if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process: - os.remove(full_filename) - - if weights_name is None: + if weight_name is None: if safe_serialization: - weights_name = LORA_WEIGHT_NAME_SAFE + weight_name = LORA_WEIGHT_NAME_SAFE else: - weights_name = LORA_WEIGHT_NAME + weight_name = LORA_WEIGHT_NAME # Save the model - save_function(state_dict, os.path.join(save_directory, weights_name)) + save_function(state_dict, os.path.join(save_directory, weight_name)) - logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")