From d761b58bfcd7213d61bbdb090f182694e17317c7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Mar 2023 11:56:10 +0100 Subject: [PATCH] [From pretrained] Speed-up loading from cache (#2515) * [From pretrained] Speed-up loading from cache * up * Fix more * fix one more bug * make style * bigger refactor * factor out function * Improve more * better * deprecate return cache folder * clean up * improve tests * up * upload * add nice tests * simplify * finish * correct * fix version * rename * Apply suggestions from code review Co-authored-by: Lucain * rename * correct doc string * correct more * Apply suggestions from code review Co-authored-by: Pedro Cuenca * apply code suggestions * finish --------- Co-authored-by: Lucain Co-authored-by: Pedro Cuenca --- ..._original_stable_diffusion_to_diffusers.py | 4 +- setup.py | 4 +- src/diffusers/configuration_utils.py | 41 +- src/diffusers/dependency_versions_table.py | 3 +- src/diffusers/models/modeling_utils.py | 76 +- src/diffusers/pipelines/pipeline_utils.py | 694 ++++++++++++------ .../stable_diffusion/convert_from_ckpt.py | 2 +- src/diffusers/schedulers/scheduling_utils.py | 3 +- src/diffusers/utils/__init__.py | 6 +- src/diffusers/utils/hub_utils.py | 18 +- tests/test_modeling_common.py | 39 + tests/test_pipelines.py | 68 +- 12 files changed, 638 insertions(+), 320 deletions(-) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 15afbccb..b9073789 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -16,7 +16,7 @@ import argparse -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt if __name__ == "__main__": @@ -125,7 +125,7 @@ if __name__ == "__main__": ) args = parser.parse_args() - pipe = load_pipeline_from_original_stable_diffusion_ckpt( + pipe = download_from_original_stable_diffusion_ckpt( checkpoint_path=args.checkpoint_path, original_config_file=args.original_config_file, image_size=args.image_size, diff --git a/setup.py b/setup.py index a029ce04..7f6f4e53 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,8 @@ _deps = [ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.10.0", + "huggingface-hub>=0.13.0", + "requests-mock==1.10.0", "importlib_metadata", "isort>=5.5.4", "jax>=0.2.8,!=0.3.2", @@ -192,6 +193,7 @@ extras["test"] = deps_list( "pytest", "pytest-timeout", "pytest-xdist", + "requests-mock", "safetensors", "sentencepiece", "scipy", diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 47b201b9..20b7b273 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R from requests import HTTPError from . 
import __version__ -from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging +from .utils import ( + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + DummyObject, + deprecate, + extract_commit_hash, + http_user_agent, + logging, +) logger = logging.get_logger(__name__) @@ -231,7 +239,11 @@ class ConfigMixin: @classmethod def load_config( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + return_unused_kwargs=False, + return_commit_hash=False, + **kwargs, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: r""" Instantiate a Python class from a config dictionary @@ -271,6 +283,10 @@ class ConfigMixin: subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether unused keyword arguments of the config should be returned. + return_commit_hash (`bool`, *optional*, defaults to `False`): + Whether the `commit_hash` of the loaded configuration should be returned. @@ -295,8 +311,10 @@ class ConfigMixin: revision = kwargs.pop("revision", None) _ = kwargs.pop("mirror", None) subfolder = kwargs.pop("subfolder", None) + user_agent = kwargs.pop("user_agent", {}) - user_agent = {"file_type": "config"} + user_agent = {**user_agent, "file_type": "config"} + user_agent = http_user_agent(user_agent) pretrained_model_name_or_path = str(pretrained_model_name_or_path) @@ -336,7 +354,6 @@ class ConfigMixin: subfolder=subfolder, revision=revision, ) - except RepositoryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" @@ -378,13 +395,23 @@ class ConfigMixin: try: # Load config dict config_dict = cls._dict_from_json_file(config_file) + + commit_hash = extract_commit_hash(config_file) except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") - if return_unused_kwargs: - return config_dict, kwargs + if not (return_unused_kwargs or return_commit_hash): + return config_dict - return config_dict + outputs = (config_dict,) + + if return_unused_kwargs: + outputs += (kwargs,) + + if return_commit_hash: + outputs += (commit_hash,) + + return outputs @staticmethod def _get_init_keys(cls): diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 4c8eaa5b..0bc67271 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -10,7 +10,8 @@ deps = { "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.10.0", + "huggingface-hub": "huggingface-hub>=0.13.0", + "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "isort": "isort>=5.5.4", "jax": "jax>=0.2.8,!=0.3.2", diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 47e1c210..a21e0954 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -458,18 +458,34 @@ class ModelMixin(torch.nn.Module): " dispatching. Please make sure to set `low_cpu_mem_usage=True`." 
) + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - - # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the - # Load model + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + user_agent=user_agent, + **kwargs, + ) + # load model model_file = None if from_flax: model_file = _get_model_file( @@ -484,20 +500,7 @@ class ModelMixin(torch.nn.Module): revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - config, unused_kwargs = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - **kwargs, + commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) @@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module): revision=revision, subfolder=subfolder, user_agent=user_agent, + commit_hash=commit_hash, ) except: # noqa: E722 pass @@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module): revision=revision, subfolder=subfolder, user_agent=user_agent, + commit_hash=commit_hash, ) if low_cpu_mem_usage: # Instantiate model with empty weights with accelerate.init_empty_weights(): - config, unused_kwargs = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - **kwargs, - ) model = cls.from_config(config, **unused_kwargs) # if device_map is None, load the state dict and move the params from meta device to the cpu @@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module): "error_msgs": [], } else: - config, unused_kwargs = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - **kwargs, - ) model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) @@ -803,6 +780,7 @@ def _get_model_file( use_auth_token, user_agent, revision, + commit_hash=None, ): pretrained_model_name_or_path = str(pretrained_model_name_or_path) if os.path.isfile(pretrained_model_name_or_path): @@ -840,7 +818,7 @@ def _get_model_file( use_auth_token=use_auth_token, user_agent=user_agent, subfolder=subfolder, - revision=revision, + revision=revision or commit_hash, ) warnings.warn( f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. 
Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", @@ -865,7 +843,7 @@ def _get_model_file( use_auth_token=use_auth_token, user_agent=user_agent, subfolder=subfolder, - revision=revision, + revision=revision or commit_hash, ) return model_file diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e2762209..67757244 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import fnmatch import importlib import inspect import os @@ -26,7 +27,8 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch -from huggingface_hub import model_info, snapshot_download +from huggingface_hub import hf_hub_download, model_info, snapshot_download +from huggingface_hub.utils import send_telemetry from packaging import version from PIL import Image from tqdm.auto import tqdm @@ -47,7 +49,6 @@ from ..utils import ( BaseOutput, deprecate, get_class_from_dynamic_module, - http_user_agent, is_accelerate_available, is_accelerate_version, is_safetensors_available, @@ -179,8 +180,7 @@ def is_safetensors_compatible(filenames, variant=None) -> bool: return True -def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]: - filenames = set(sibling.rfilename for sibling in info.siblings) +def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, @@ -220,6 +220,177 @@ def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], return usable_filenames, variant_filenames +def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames): + info = model_info( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + revision=None, + ) + filenames = set(sibling.rfilename for sibling in info.siblings) + comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) + comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] + + if set(comp_model_filenames) == set(model_filenames): + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", + FutureWarning, + ) + else: + warnings.warn( + f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. 
\n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", + FutureWarning, + ) + + +def maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module +): + """Simple helper method to raise or warn in case an incorrect module has been passed""" + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + +def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module): + """Simple helper method to retrieve class object of module as well as potential parent class objects""" + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + + class_obj = getattr(pipeline_module, class_name) + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + return class_obj, class_candidates + + +def load_sub_model( + library_name: str, + class_name: str, + importable_classes: List[Any], + pipelines: Any, + is_pipeline_module: bool, + pipeline_class: Any, + torch_dtype: torch.dtype, + provider: Any, + sess_options: Any, + device_map: Optional[Union[Dict[str, torch.device], str]], + model_variants: Dict[str, str], + name: str, + from_flax: bool, + variant: str, + low_cpu_mem_usage: bool, + cached_folder: Union[str, os.PathLike], +): + """Helper method to load the module `name` from `library_name` and `class_name`""" + # retrieve class candidates + class_obj, class_candidates = get_class_obj_and_candidates( + library_name, class_name, importable_classes, pipelines, is_pipeline_module + ) + + load_method_name = None + # retrieve load method name + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + # if load method name is None, then we have a dummy module -> raise Error + if load_method_name is None: + none_module = class_obj.__module__ + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." 
+ ) + + load_method = getattr(class_obj, load_method_name) + + # add kwargs to loading method + loading_kwargs = {} + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + + # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. + # This makes sure that the weights won't be initialized which significantly speeds up loading. + if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + loading_kwargs["variant"] = model_variants.pop(name, None) + if from_flax: + loading_kwargs["from_flax"] = True + + # the following can be deleted once the minimum required `transformers` version + # is higher than 4.27 + if ( + is_transformers_model + and loading_kwargs["variant"] is not None + and transformers_version < version.parse("4.27.0") + ): + raise ImportError( + f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" + ) + elif is_transformers_model and loading_kwargs["variant"] is None: + loading_kwargs.pop("variant") + + # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` + if not (from_flax and is_transformers_model): + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + else: + loading_kwargs["low_cpu_mem_usage"] = False + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + return loaded_sub_model + + class DiffusionPipeline(ConfigMixin): r""" Base class for all models. @@ -524,8 +695,6 @@ class DiffusionPipeline(ConfigMixin): also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. - return_cached_folder (`bool`, *optional*, defaults to `False`): - If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the specific pipeline class. The overwritten components are then directly passed to the pipelines @@ -583,13 +752,12 @@ class DiffusionPipeline(ConfigMixin): sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - return_cached_folder = kwargs.pop("return_cached_folder", False) variant = kwargs.pop("variant", None) # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.load_config( + cached_folder = cls.download( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, @@ -598,117 +766,18 @@ class DiffusionPipeline(ConfigMixin): local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, - ) - - # retrieve all folder_names that contain relevant files - folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] - - if not local_files_only: - info = model_info( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - revision=revision, - ) - model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant) - model_folder_names = set([os.path.split(f)[0] for f in model_filenames]) - - if revision in DEPRECATED_REVISION_ARGS and version.parse( - version.parse(__version__).base_version - ) >= version.parse("0.17.0"): - info = model_info( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - revision=None, - ) - comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision) - comp_model_filenames = [ - ".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames - ] - - if set(comp_model_filenames) == set(model_filenames): - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - else: - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", - FutureWarning, - ) - - # all filenames compatible with variant will be added - allow_patterns = list(model_filenames) - - # allow all patterns from non-model folders - # this enables downloading schedulers, tokenizers, ... 
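# A minimal, illustrative sketch of the filtering mechanism this block configures
# (both in the removed code here and in the new `DiffusionPipeline.download` below):
# `snapshot_download` fetches only files matching `allow_patterns` and skips files
# matching `ignore_patterns`. The repo id and patterns are examples, not values
# fixed by the patch.
from huggingface_hub import snapshot_download

cached_folder = snapshot_download(
    "hf-internal-testing/tiny-stable-diffusion-pipe",  # example repo
    allow_patterns=["model_index.json", "unet/*", "scheduler/*"],
    ignore_patterns=["*.msgpack", "*.safetensors"],
)
# `cached_folder` now points at a local snapshot containing only the matched files.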
- allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names] - # also allow downloading config.jsons with the model - allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names] - - allow_patterns += [ - SCHEDULER_CONFIG_NAME, - CONFIG_NAME, - cls.config_name, - CUSTOM_PIPELINE_FILE_NAME, - ] - - if from_flax: - ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant): - ignore_patterns = ["*.bin", "*.msgpack"] - - safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) - safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")]) - if ( - len(safetensors_variant_filenames) > 0 - and safetensors_model_filenames != safetensors_variant_filenames - ): - logger.warn( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) - - else: - ignore_patterns = ["*.safetensors", "*.msgpack"] - - bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")]) - bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")]) - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warn( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) - - else: - # allow everything since it has to be downloaded anyways - ignore_patterns = allow_patterns = None - - if cls != DiffusionPipeline: - requested_pipeline_class = cls.__name__ - else: - requested_pipeline_class = config_dict.get("_class_name", cls.__name__) - user_agent = {"pipeline_class": requested_pipeline_class} - if custom_pipeline is not None and not custom_pipeline.endswith(".py"): - user_agent["custom_pipeline"] = custom_pipeline - - user_agent = http_user_agent(user_agent) - - # download all allow_patterns - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, + from_flax=from_flax, + custom_pipeline=custom_pipeline, + variant=variant, ) else: cached_folder = pretrained_model_name_or_path - config_dict = cls.load_config(cached_folder) - # retrieve which subfolders should load variants + config_dict = cls.load_config(cached_folder) + + # 2. Define which model components should load variants + # We retrieve the information by matching whether variant + # model checkpoints exist in the subfolders model_variants = {} if variant is not None: for folder in os.listdir(cached_folder): @@ -718,7 +787,7 @@ class DiffusionPipeline(ConfigMixin): if variant_exists: model_variants[folder] = variant - # 2. Load the pipeline class, if using custom module then load it from the hub + # 3. 
Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if custom_pipeline is not None: if custom_pipeline.endswith(".py"): @@ -738,7 +807,7 @@ class DiffusionPipeline(ConfigMixin): diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) - # To be removed in 1.0.0 + # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( version.parse(config_dict["_diffusers_version"]).base_version ) <= version.parse("0.5.1"): @@ -757,6 +826,9 @@ class DiffusionPipeline(ConfigMixin): ) deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + # 4. Define expected modules given pipeline signature + # and define non-None initialized modules (=`init_kwargs`) + # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here @@ -788,6 +860,7 @@ class DiffusionPipeline(ConfigMixin): " separately if you need it." ) + # 5. Throw nice warnings / errors for fast accelerate loading if len(unused_kwargs) > 0: logger.warning( f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." @@ -823,135 +896,50 @@ class DiffusionPipeline(ConfigMixin): # import it here to avoid circular import from diffusers import pipelines - # 3. Load each module in the pipeline + # 6. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): - # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names if class_name.startswith("Flax"): class_name = class_name[4:] + # 6.2 Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) + importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name] loaded_sub_model = None - # if the model is in a pipeline module, then we load it from the pipeline + # 6.3 Use passed sub model or load class_name from library_name if name in passed_class_obj: - # 1. check that passed_class_obj has correct parent class - if not is_pipeline_module: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - expected_class_obj = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - expected_class_obj = class_candidate - - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" - ) - else: - logger.warning( - f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" - " has the correct type" - ) - - # set passed class object - loaded_sub_model = passed_class_obj[name] - elif is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = getattr(pipeline_module, class_name) - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in importable_classes.keys()} - else: - # else we just import it from the library. 
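# A small sketch of the dynamic import performed for non-pipeline modules, as in
# `get_class_obj_and_candidates` above; the library and class names are illustrative
# (in practice they come from the pipeline's `model_index.json`).
import importlib

library_name, class_name = "transformers", "CLIPTextModel"
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
# Each matching candidate parent class maps to a load method such as
# `from_pretrained`, which tells the pipeline how to instantiate this sub-model.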
- library = importlib.import_module(library_name) - - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - if loaded_sub_model is None: - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - if load_method_name is None: - none_module = class_obj.__module__ - is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( - TRANSFORMERS_DUMMY_MODULES_FOLDER - ) - if is_dummy_path and "dummy" in none_module: - # call class_obj for nice error message of missing requirements - class_obj() - - raise ValueError( - f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" - f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." - ) - - load_method = getattr(class_obj, load_method_name) - loading_kwargs = {} - - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers.OnnxRuntimeModel): - loading_kwargs["provider"] = provider - loading_kwargs["sess_options"] = sess_options - - is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) - - if is_transformers_available(): - transformers_version = version.parse(version.parse(transformers.__version__).base_version) - else: - transformers_version = "N/A" - - is_transformers_model = ( - is_transformers_available() - and issubclass(class_obj, PreTrainedModel) - and transformers_version >= version.parse("4.20.0") + # if the model is in a pipeline module, then we load it from the pipeline + # check that passed_class_obj has correct parent class + maybe_raise_or_warn( + library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module ) - # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. - # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. - # This makes sure that the weights won't be initialized which significantly speeds up loading. 
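# Why `low_cpu_mem_usage=True` helps, sketched with `accelerate` (the same utility
# `ModelMixin.from_pretrained` uses earlier in this patch): modules created under
# `init_empty_weights` live on the meta device, so no time or RAM is spent on
# random initialization before the real weights are loaded into place.
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta -> parameters allocated without backing storage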
- if is_diffusers_model or is_transformers_model: - loading_kwargs["device_map"] = device_map - loading_kwargs["variant"] = model_variants.pop(name, None) - if from_flax: - loading_kwargs["from_flax"] = True - - # the following can be deleted once the minimum required `transformers` version - # is higher than 4.27 - if ( - is_transformers_model - and loading_kwargs["variant"] is not None - and transformers_version < version.parse("4.27.0") - ): - raise ImportError( - f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" - ) - elif is_transformers_model and loading_kwargs["variant"] is None: - loading_kwargs.pop("variant") - - # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` - if not (from_flax and is_transformers_model): - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - else: - loading_kwargs["low_cpu_mem_usage"] = False - - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder, **loading_kwargs) + loaded_sub_model = passed_class_obj[name] + else: + # load sub model + loaded_sub_model = load_sub_model( + library_name=library_name, + class_name=class_name, + importable_classes=importable_classes, + pipelines=pipelines, + is_pipeline_module=is_pipeline_module, + pipeline_class=pipeline_class, + torch_dtype=torch_dtype, + provider=provider, + sess_options=sess_options, + device_map=device_map, + model_variants=model_variants, + name=name, + from_flax=from_flax, + variant=variant, + low_cpu_mem_usage=low_cpu_mem_usage, + cached_folder=cached_folder, + ) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - # 4. Potentially add passed objects if expected + # 7. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) optional_modules = pipeline_class._optional_components @@ -964,13 +952,251 @@ class DiffusionPipeline(ConfigMixin): f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) - # 5. Instantiate the pipeline + # 8. Instantiate the pipeline model = pipeline_class(**init_kwargs) + return_cached_folder = kwargs.pop("return_cached_folder", False) if return_cached_folder: + message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`." + deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs) return model, cached_folder + return model + @classmethod + def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: + r""" + Download and cache a PyTorch diffusion pipeline from pre-trained pipeline weights. + + Parameters: + pretrained_model_name (`str` or `os.PathLike`, *optional*): + Should be a string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. 
+ custom_pipeline (`str`, *optional*): + + + + This is an experimental feature and is likely to change in the future. + + + + Can be either: + + - A string, the *repo id* of a custom pipeline hosted inside a model repo on + https://huggingface.co/. Valid repo ids have to be located under a user or organization name, + like `hf-internal-testing/diffusers-dummy-pipeline`. + + + + It is required that the model repo has a file called `pipeline.py` that defines the custom + pipeline. + + + + - A string, the *file name* of a community pipeline hosted on GitHub under + https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to + match exactly the file name without `.py` located under the above link, *e.g.* + `clip_guided_stable_diffusion`. + + + + Community pipelines are always loaded from the current `main` branch of GitHub. + + + + - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. + + + + It is required that the directory has a file called `pipeline.py` that defines the custom + pipeline. + + + + For more information on how to load and create custom pipelines, please have a look at [Loading and + Adding Custom + Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of + `diffusers` when loading from GitHub): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a + custom pipeline from GitHub. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + variant (`str`, *optional*): + If specified, load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. 
`variant` is + ignored when using `from_flax`. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` + + + + + + Activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#notice-on-telemetry-logging) to use this + method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + custom_pipeline = kwargs.pop("custom_pipeline", None) + variant = kwargs.pop("variant", None) + + pipeline_is_cached = False + allow_patterns = None + ignore_patterns = None + + user_agent = {"pipeline_class": cls.__name__} + if custom_pipeline is not None and not custom_pipeline.endswith(".py"): + user_agent["custom_pipeline"] = custom_pipeline + + if not local_files_only: + info = model_info( + pretrained_model_name, + use_auth_token=use_auth_token, + revision=revision, + ) + user_agent["pretrained_model_name"] = pretrained_model_name + send_telemetry("pipelines", library_name="diffusers", library_version=__version__, user_agent=user_agent) + commit_hash = info.sha + + # try loading the config file + config_file = hf_hub_download( + pretrained_model_name, + cls.config_name, + cache_dir=cache_dir, + revision=commit_hash, + proxies=proxies, + force_download=force_download, + resume_download=resume_download, + use_auth_token=use_auth_token, + ) + + config_dict = cls._dict_from_json_file(config_file) + config_is_cached = True + + # retrieve all folder_names that contain relevant files + folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] + + filenames = set(sibling.rfilename for sibling in info.siblings) + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + + # if the whole pipeline is cached we don't have to ping the Hub + if revision in DEPRECATED_REVISION_ARGS and version.parse( + version.parse(__version__).base_version + ) >= version.parse("0.17.0"): + warn_deprecated_model_variant( + pretrained_model_name, use_auth_token, variant, revision, model_filenames + ) + + model_folder_names = set([os.path.split(f)[0] for f in model_filenames]) + + # all filenames compatible with variant will be added + allow_patterns = list(model_filenames) + + # allow all patterns from non-model folders + # this enables downloading schedulers, tokenizers, ... 
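# The caching trick used above, shown in isolation: resolve the repo's current
# commit once via `model_info`, then request files pinned to that exact commit.
# A second call with the same hash is served from the local snapshot without
# another download. The repo id below is an example, not fixed by the patch.
from huggingface_hub import hf_hub_download, model_info

repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
commit_hash = model_info(repo_id).sha
config_file = hf_hub_download(repo_id, "model_index.json", revision=commit_hash)
# `config_file` lives under .../snapshots/<commit_hash>/, so the pipeline can
# later check whether the whole snapshot is already cached before downloading.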
+ allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names] + # also allow downloading config.json files with the model + allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names] + + allow_patterns += [ + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + CUSTOM_PIPELINE_FILE_NAME, + ] + + if from_flax: + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] + elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant): + ignore_patterns = ["*.bin", "*.msgpack"] + + safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) + safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")]) + if ( + len(safetensors_variant_filenames) > 0 + and safetensors_model_filenames != safetensors_variant_filenames + ): + logger.warn( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}]\nIf this behavior is not expected, please check your folder structure." + ) + else: + ignore_patterns = ["*.safetensors", "*.msgpack"] + + bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")]) + bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")]) + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warn( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}]\nIf this behavior is not expected, please check your folder structure." 
+ ) + + if config_is_cached: + re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] + re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] + + expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)] + expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files) + + if pipeline_is_cached: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + # download all allow_patterns - ignore_patterns + cached_folder = snapshot_download( + pretrained_model_name, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + return cached_folder + @staticmethod def _get_signature_keys(obj): parameters = inspect.signature(obj.__init__).parameters diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index bc3d439b..857fb296 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -954,7 +954,7 @@ def stable_unclip_image_noising_components( return image_normalizer, image_noising_scheduler -def load_pipeline_from_original_stable_diffusion_ckpt( +def download_from_original_stable_diffusion_ckpt( checkpoint_path: str, original_config_file: str = None, image_size: int = 512, diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 386d60b2..a4121f75 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -136,10 +136,11 @@ class SchedulerMixin: """ - config, kwargs = cls.load_config( + config, kwargs, commit_hash = cls.load_config( pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, return_unused_kwargs=True, + return_commit_hash=True, **kwargs, ) return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 64d5c695..196b3b02 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -35,7 +35,11 @@ from .constants import ( from .deprecation_utils import deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module -from .hub_utils import HF_HUB_OFFLINE, http_user_agent +from .hub_utils import ( + HF_HUB_OFFLINE, + extract_commit_hash, + http_user_agent, +) from .import_utils import ( ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 6ddbac36..916b18d3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -15,6 +15,7 @@ import os +import re import sys import traceback from pathlib import Path @@ -22,6 +23,7 @@ from typing import Dict, Optional, Union from uuid import uuid4 from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami +from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.utils import is_jinja_available from .. 
import __version__ @@ -132,6 +134,20 @@ def create_model_card(args, model_name): model_card.save(card_path) +def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None): + """ + Extracts the commit hash from a resolved filename toward a cache file. + """ + if resolved_file is None or commit_hash is not None: + return commit_hash + resolved_file = str(Path(resolved_file).as_posix()) + search = re.search(r"snapshots/([^/]+)/", resolved_file) + if search is None: + return None + commit_hash = search.groups()[0] + return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None + + # Old default cache path, potentially to be migrated. # This logic was more or less taken from `transformers`, with the following differences: # - Diffusers doesn't use custom environment variables to specify the cache path. @@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] old_cache_dir = Path(old_cache_dir).expanduser() new_cache_dir = Path(new_cache_dir).expanduser() - for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob + for old_blob_path in old_cache_dir.glob("**/blobs/*"): if old_blob_path.is_file() and not old_blob_path.is_symlink(): new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) new_blob_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a960df0c..e9b7d5f3 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -20,6 +20,7 @@ import unittest.mock as mock from typing import Dict, List, Tuple import numpy as np +import requests_mock import torch from requests.exceptions import HTTPError @@ -29,6 +30,13 @@ from diffusers.utils import torch_device class ModelUtilsTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + + import diffusers + + diffusers.utils.import_utils._safetensors_available = True + def test_accelerate_loading_error_message(self): with self.assertRaises(ValueError) as error_context: UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") @@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase): if p1.data.ne(p2.data).sum() > 0: assert False, "Parameters not the same!" + def test_one_request_upon_cached(self): + # TODO: For some reason this test fails on MPS where no HEAD call is made. 
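# The counting pattern these tests build on, in isolation: with `real_http=True`,
# requests_mock lets calls through to the real server while still recording them,
# so the request history can be asserted on afterwards (as the tests below do).
# This is an illustrative sketch; it performs one real network request.
import requests
import requests_mock

with requests_mock.mock(real_http=True) as m:
    requests.head("https://huggingface.co")
    print([r.method for r in m.request_history])  # e.g. ["HEAD"]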
+ if torch_device == "mps": + return + + import diffusers + + diffusers.utils.import_utils._safetensors_available = False + + with tempfile.TemporaryDirectory() as tmpdirname: + with requests_mock.mock(real_http=True) as m: + UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname + ) + + download_requests = [r.method for r in m.request_history] + assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model" + assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" + + with requests_mock.mock(real_http=True) as m: + UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname + ) + + cache_requests = [r.method for r in m.request_history] + assert ( + "HEAD" == cache_requests[0] and len(cache_requests) == 1 + ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + + diffusers.utils.import_utils._safetensors_available = True + class ModelTesterMixin: def test_from_save_pretrained(self): diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 5f9d0aa9..211a7c28 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -25,6 +25,7 @@ import unittest.mock as mock import numpy as np import PIL +import requests_mock import safetensors.torch import torch from parameterized import parameterized @@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False class DownloadTests(unittest.TestCase): + def test_one_request_upon_cached(self): + # TODO: For some reason this test fails on MPS where no HEAD call is made. + if torch_device == "mps": + return + + with tempfile.TemporaryDirectory() as tmpdirname: + with requests_mock.mock(real_http=True) as m: + DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + download_requests = [r.method for r in m.request_history] + assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry" + assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json" + assert ( + len(download_requests) == 33 + ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json" + + with requests_mock.mock(real_http=True) as m: + DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + cache_requests = [r.method for r in m.request_history] + assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD" + assert cache_requests.count("GET") == 1, "model info is only GET" + assert ( + len(cache_requests) == 2 + ), "We should call only `model_info` to check for _commit hash and `send_telemetry`" + def test_download_only_pytorch(self): with tempfile.TemporaryDirectory() as tmpdirname: # pipeline has Flax weights - _ = DiffusionPipeline.from_pretrained( + tmpdirname = DiffusionPipeline.download( "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname ) - all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] files = [item for sublist in all_root_files for item in sublist] # None of the downloaded files should be a flax file even if we have some here: @@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase): def 
test_download_safetensors(self): with tempfile.TemporaryDirectory() as tmpdirname: # pipeline has Flax weights - _ = DiffusionPipeline.from_pretrained( + tmpdirname = DiffusionPipeline.download( "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors", safety_checker=None, cache_dir=tmpdirname, ) - all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] files = [item for sublist in all_root_files for item in sublist] # None of the downloaded files should be a pytorch file even if we have some here: @@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase): other_format = ".bin" if safe_avail else ".safetensors" with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionPipeline.from_pretrained( + tmpdirname = StableDiffusionPipeline.download( "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname ) - all_root_files = [ - t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")) - ] + all_root_files = [t[-1] for t in os.walk(tmpdirname)] files = [item for sublist in all_root_files for item in sublist] # None of the downloaded files should be a variant file even if we have some here: @@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase): variant = "fp16" with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionPipeline.from_pretrained( + tmpdirname = StableDiffusionPipeline.download( "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant ) - all_root_files = [ - t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")) - ] + all_root_files = [t[-1] for t in os.walk(tmpdirname)] files = [item for sublist in all_root_files for item in sublist] # None of the downloaded files should be a non-variant file even if we have some here: @@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase): variant = "no_ema" with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionPipeline.from_pretrained( + tmpdirname = StableDiffusionPipeline.download( "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant ) - snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots") - all_root_files = [t[-1] for t in os.walk(snapshots)] + all_root_files = [t[-1] for t in os.walk(tmpdirname)] files = [item for sublist in all_root_files for item in sublist] - unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet")) + unet_files = os.listdir(os.path.join(tmpdirname, "unet")) # Some of the downloaded files should be a non-variant file, check: # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet @@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase): for variant in [None, "no_ema"]: with self.assertRaises(OSError) as error_context: with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionPipeline.from_pretrained( + tmpdirname = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant=variant, @@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase): # text encoder has fp16 variants so we can load it with tempfile.TemporaryDirectory() as tmpdirname: - pipe = StableDiffusionPipeline.from_pretrained( + tmpdirname = StableDiffusionPipeline.download( "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16" ) - assert 
pipe is not None - snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots") - all_root_files = [t[-1] for t in os.walk(snapshots)] + all_root_files = [t[-1] for t in os.walk(tmpdirname)] files = [item for sublist in all_root_files for item in sublist] # None of the downloaded files should be a non-variant file even if we have some here: @@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase): @slow @require_torch_gpu - def test_load_pipeline_from_git(self): + def test_download_from_git(self): clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
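Taken together, the workflow this patch enables — and the replacement for the now-deprecated `return_cached_folder=True` spelled out in the deprecation message above — looks like the following sketch; the repo id is illustrative:

from diffusers import DiffusionPipeline

# 1. download (or re-use) the snapshot; returns the local folder path
cached_folder = DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe")
# 2. load the pipeline from the cached folder -- no further Hub requests needed
pipe = DiffusionPipeline.from_pretrained(cached_folder)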