[Lora] correct lora saving & loading (#2655)

* [Lora] correct lora saving & loading

* fix final

* Apply suggestions from code review
This commit is contained in:
Patrick von Platen 2023-03-14 17:55:43 +01:00 committed by GitHub
parent 7c1b347702
commit d185c0dfa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 23 additions and 25 deletions

View File

@ -19,7 +19,7 @@ import torch
from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
if is_safetensors_available():
@ -150,13 +150,14 @@ class UNet2DConditionLoadersMixin:
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
if weight_name is None:
weight_name = LORA_WEIGHT_NAME_SAFE
# Let's first try to load .safetensors weights
if (is_safetensors_available() and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@ -169,14 +170,13 @@ class UNet2DConditionLoadersMixin:
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
if weight_name == LORA_WEIGHT_NAME_SAFE:
weight_name = None
# try loading non-safetensors weights
pass
if model_file is None:
if weight_name is None:
weight_name = LORA_WEIGHT_NAME
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@ -225,9 +225,10 @@ class UNet2DConditionLoadersMixin:
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weights_name: str = None,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
**kwargs,
):
r"""
Save an attention processor to a directory, so that it can be re-loaded using the
@ -245,6 +246,12 @@ class UNet2DConditionLoadersMixin:
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
"""
weight_name = weight_name or deprecate(
"weights_name",
"0.18.0",
"`weights_name` is deprecated, please use `weight_name` instead.",
take_from=kwargs,
)
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@ -265,22 +272,13 @@ class UNet2DConditionLoadersMixin:
# Save the model
state_dict = model_to_save.state_dict()
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
if weights_name is None:
if weight_name is None:
if safe_serialization:
weights_name = LORA_WEIGHT_NAME_SAFE
weight_name = LORA_WEIGHT_NAME_SAFE
else:
weights_name = LORA_WEIGHT_NAME
weight_name = LORA_WEIGHT_NAME
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")