diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9848ce79..31fdc46d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -142,6 +142,17 @@ class UNet2DConditionLoadersMixin: revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True user_agent = { "file_type": "attn_procs_weights", @@ -151,7 +162,7 @@ class UNet2DConditionLoadersMixin: model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights - if (is_safetensors_available() and weight_name is None) or ( + if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: @@ -169,10 +180,11 @@ class UNet2DConditionLoadersMixin: user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") - except EnvironmentError: + except IOError as e: + if not allow_pickle: + raise e # try loading non-safetensors weights pass - if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a21e0954..a7c2c750 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -392,6 +392,10 @@ class ModelMixin(torch.nn.Module): variant (`str`, *optional*): If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is ignored when using `from_flax`. 
+ use_safetensors (`bool`, *optional*): + If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to + `None` (the default), the pipeline will load using `safetensors` if safetensors weights are available + *and* if `safetensors` is installed. If set to `False` the pipeline will *not* use `safetensors`. @@ -423,6 +427,17 @@ class ModelMixin(torch.nn.Module): device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False @@ -509,7 +524,7 @@ class ModelMixin(torch.nn.Module): model = load_flax_checkpoint_in_pytorch_model(model, model_file) else: - if is_safetensors_available(): + if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, @@ -525,7 +540,9 @@ class ModelMixin(torch.nn.Module): user_agent=user_agent, commit_hash=commit_hash, ) - except: # noqa: E722 + except IOError as e: + if not allow_pickle: + raise e pass if model_file is None: model_file = _get_model_file( diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 917d2dca..6560f305 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -694,6 +694,10 @@ class DiffusionPipeline(ConfigMixin): also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. 
If you are using an older version of torch, setting this argument to `True` will raise an error. + use_safetensors (`bool`, *optional*): + If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the + default), the pipeline will load using `safetensors` if the safetensors weights are available *and* if + `safetensors` is installed. If set to `False` the pipeline will *not* use `safetensors`. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the specific pipeline class. The overwritten components are then directly passed to the pipelines @@ -752,6 +756,7 @@ class DiffusionPipeline(ConfigMixin): device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) + kwargs.pop("use_safetensors", None if is_safetensors_available() else False) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -1068,6 +1073,17 @@ class DiffusionPipeline(ConfigMixin): from_flax = kwargs.pop("from_flax", False) custom_pipeline = kwargs.pop("custom_pipeline", None) variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True pipeline_is_cached = False allow_patterns = None @@ -1123,9 +1139,17 @@ class DiffusionPipeline(ConfigMixin): CUSTOM_PIPELINE_FILE_NAME, ] + if ( + use_safetensors + and not allow_pickle + and not is_safetensors_compatible(model_filenames, variant=variant) + ): + raise EnvironmentError( + f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" + ) if from_flax: ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant): + elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant): ignore_patterns = ["*.bin", "*.msgpack"] safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 24707df9..5ca0bae6 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -440,7 +440,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): # LoRA and no LoRA should NOT be the same assert (sample - old_sample).abs().max() > 1e-4 - def test_lora_save_load_safetensors_load_torch(self): + def test_lora_save_safetensors_load_torch(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -475,6 +475,43 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): new_model.to(torch_device) new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") + def test_lora_save_torch_force_load_safetensors_error(self): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = 
self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + lora_attn_procs = {} + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = lora_attn_procs[name].to(model.device) + + model.set_attn_processor(lora_attn_procs) + # Saving as torch, properly reloads with directly filename + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.to(torch_device) + with self.assertRaises(IOError) as e: + new_model.load_attn_procs(tmpdirname, use_safetensors=True) + self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception)) + def test_lora_on_off(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index b61d097d..daf88417 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -108,6 +108,17 @@ class DownloadTests(unittest.TestCase): # We need to never convert this tiny model to safetensors for this test to pass assert not any(f.endswith(".safetensors") for f in files) 
+ def test_force_safetensors_error(self): + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + with self.assertRaises(EnvironmentError): + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe-no-safetensors", + safety_checker=None, + cache_dir=tmpdirname, + use_safetensors=True, + ) + def test_returned_cached_folder(self): prompt = "hello" pipe = StableDiffusionPipeline.from_pretrained(