Adding `use_safetensors` argument to give more control to users (#2123)

* Adding `use_safetensors` argument to give more control to users

about which weights they use.

* Doc style.

* Rebased (not functional).

* Rebased and functional with tests.

* Style.

* Apply suggestions from code review

* Style.

* Addressing comments.

* Update tests/test_pipelines.py

Co-authored-by: Will Berman <wlbberman@gmail.com>

* Black ???

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
This commit is contained in:
Nicolas Patry 2023-03-16 15:57:43 +01:00 committed by GitHub
parent e828232780
commit d9227cf788
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 108 additions and 7 deletions

View File

@ -142,6 +142,17 @@ class UNet2DConditionLoadersMixin:
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
@ -151,7 +162,7 @@ class UNet2DConditionLoadersMixin:
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (is_safetensors_available() and weight_name is None) or (
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
@ -169,10 +180,11 @@ class UNet2DConditionLoadersMixin:
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
except IOError as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,

View File

@ -392,6 +392,10 @@ class ModelMixin(torch.nn.Module):
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
`None` (the default). The pipeline will load using `safetensors` if safetensors weights are available
*and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
<Tip>
@ -423,6 +427,17 @@ class ModelMixin(torch.nn.Module):
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
@ -509,7 +524,7 @@ class ModelMixin(torch.nn.Module):
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
if is_safetensors_available():
if use_safetensors:
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
@ -525,7 +540,9 @@ class ModelMixin(torch.nn.Module):
user_agent=user_agent,
commit_hash=commit_hash,
)
except: # noqa: E722
except IOError as e:
if not allow_pickle:
raise e
pass
if model_file is None:
model_file = _get_model_file(

View File

@ -694,6 +694,10 @@ class DiffusionPipeline(ConfigMixin):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if
`safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
@ -752,6 +756,7 @@ class DiffusionPipeline(ConfigMixin):
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@ -1068,6 +1073,17 @@ class DiffusionPipeline(ConfigMixin):
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
pipeline_is_cached = False
allow_patterns = None
@ -1123,9 +1139,17 @@ class DiffusionPipeline(ConfigMixin):
CUSTOM_PIPELINE_FILE_NAME,
]
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, variant=variant)
):
raise EnvironmentError(
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])

View File

@ -440,7 +440,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_load_safetensors_load_torch(self):
def test_lora_save_safetensors_load_torch(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@ -475,6 +475,43 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
def test_lora_save_torch_force_load_safetensors_error(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
with self.assertRaises(IOError) as e:
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
def test_lora_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@ -108,6 +108,17 @@ class DownloadTests(unittest.TestCase):
# We need to never convert this tiny model to safetensors for this test to pass
assert not any(f.endswith(".safetensors") for f in files)
def test_force_safetensors_error(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
with self.assertRaises(EnvironmentError):
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-no-safetensors",
safety_checker=None,
cache_dir=tmpdirname,
use_safetensors=True,
)
def test_returned_cached_folder(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(