[Lora] correct lora saving & loading (#2655)
* [Lora] correct lora saving & loading * fix final * Apply suggestions from code review
This commit is contained in:
parent
7c1b347702
commit
d185c0dfa7
|
@ -19,7 +19,7 @@ import torch
|
|||
|
||||
from .models.cross_attention import LoRACrossAttnProcessor
|
||||
from .models.modeling_utils import _get_model_file
|
||||
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
|
||||
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
|
@ -150,13 +150,14 @@ class UNet2DConditionLoadersMixin:
|
|||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
|
||||
if weight_name is None:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
# Let's first try to load .safetensors weights
|
||||
if (is_safetensors_available() and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
|
@ -169,14 +170,13 @@ class UNet2DConditionLoadersMixin:
|
|||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except EnvironmentError:
|
||||
if weight_name == LORA_WEIGHT_NAME_SAFE:
|
||||
weight_name = None
|
||||
# try loading non-safetensors weights
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
|
@ -225,9 +225,10 @@ class UNet2DConditionLoadersMixin:
|
|||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
weights_name: str = None,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Save an attention processor to a directory, so that it can be re-loaded using the
|
||||
|
@ -245,6 +246,12 @@ class UNet2DConditionLoadersMixin:
|
|||
need to replace `torch.save` by another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
"""
|
||||
weight_name = weight_name or deprecate(
|
||||
"weights_name",
|
||||
"0.18.0",
|
||||
"`weights_name` is deprecated, please use `weight_name` instead.",
|
||||
take_from=kwargs,
|
||||
)
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
@ -265,22 +272,13 @@ class UNet2DConditionLoadersMixin:
|
|||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Clean the folder from a previous save
|
||||
for filename in os.listdir(save_directory):
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||
# in distributed settings to avoid race conditions.
|
||||
weights_no_suffix = weights_name.replace(".bin", "")
|
||||
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
|
||||
os.remove(full_filename)
|
||||
|
||||
if weights_name is None:
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weights_name = LORA_WEIGHT_NAME_SAFE
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weights_name = LORA_WEIGHT_NAME
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weights_name))
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
|
Loading…
Reference in New Issue