Adding `use_safetensors` argument to give more control to users (#2123)
* Adding `use_safetensors` argument to give more control to users about which weights they use. * Doc style. * Rebased (not functional). * Rebased and functional with tests. * Style. * Apply suggestions from code review * Style. * Addressing comments. * Update tests/test_pipelines.py Co-authored-by: Will Berman <wlbberman@gmail.com> * Black ??? --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Will Berman <wlbberman@gmail.com>
This commit is contained in:
parent
e828232780
commit
d9227cf788
|
@ -142,6 +142,17 @@ class UNet2DConditionLoadersMixin:
|
|||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
|
@ -151,7 +162,7 @@ class UNet2DConditionLoadersMixin:
|
|||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (is_safetensors_available() and weight_name is None) or (
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
|
@ -169,10 +180,11 @@ class UNet2DConditionLoadersMixin:
|
|||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except EnvironmentError:
|
||||
except IOError as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
|
|
|
@ -392,6 +392,10 @@ class ModelMixin(torch.nn.Module):
|
|||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
use_safetensors (`bool`, *optional*):
|
||||
If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
|
||||
`None` (the default), the pipeline will load using `safetensors` if safetensors weights are available
|
||||
*and* if `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
@ -423,6 +427,17 @@ class ModelMixin(torch.nn.Module):
|
|||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
allow_pickle = True
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
|
@ -509,7 +524,7 @@ class ModelMixin(torch.nn.Module):
|
|||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||
else:
|
||||
if is_safetensors_available():
|
||||
if use_safetensors:
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
|
@ -525,7 +540,9 @@ class ModelMixin(torch.nn.Module):
|
|||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
except: # noqa: E722
|
||||
except IOError as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
pass
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
|
|
|
@ -694,6 +694,10 @@ class DiffusionPipeline(ConfigMixin):
|
|||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
use_safetensors (`bool`, *optional*):
|
||||
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
|
||||
default), the pipeline will load using `safetensors` if the safetensors weights are available *and* if
|
||||
`safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||
|
@ -752,6 +756,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
|
@ -1068,6 +1073,17 @@ class DiffusionPipeline(ConfigMixin):
|
|||
from_flax = kwargs.pop("from_flax", False)
|
||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = is_safetensors_available()
|
||||
allow_pickle = True
|
||||
|
||||
pipeline_is_cached = False
|
||||
allow_patterns = None
|
||||
|
@ -1123,9 +1139,17 @@ class DiffusionPipeline(ConfigMixin):
|
|||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
if (
|
||||
use_safetensors
|
||||
and not allow_pickle
|
||||
and not is_safetensors_compatible(model_filenames, variant=variant)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
|
||||
)
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
|
||||
elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
|
||||
|
|
|
@ -440,7 +440,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
# LoRA and no LoRA should NOT be the same
|
||||
assert (sample - old_sample).abs().max() > 1e-4
|
||||
|
||||
def test_lora_save_load_safetensors_load_torch(self):
|
||||
def test_lora_save_safetensors_load_torch(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
@ -475,6 +475,43 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
|
||||
|
||||
def test_lora_save_torch_force_load_safetensors_error(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
lora_attn_procs = {}
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
|
||||
|
||||
model.set_attn_processor(lora_attn_procs)
|
||||
# Saving as torch, then forcing `use_safetensors=True` on reload must raise an error
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
with self.assertRaises(IOError) as e:
|
||||
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
|
||||
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
|
||||
|
||||
def test_lora_on_off(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
|
|
@ -108,6 +108,17 @@ class DownloadTests(unittest.TestCase):
|
|||
# We need to never convert this tiny model to safetensors for this test to pass
|
||||
assert not any(f.endswith(".safetensors") for f in files)
|
||||
|
||||
def test_force_safetensors_error(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
with self.assertRaises(EnvironmentError):
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe-no-safetensors",
|
||||
safety_checker=None,
|
||||
cache_dir=tmpdirname,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
def test_returned_cached_folder(self):
|
||||
prompt = "hello"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
|
|
Loading…
Reference in New Issue