[From pretrained] Speed-up loading from cache (#2515)

* [From pretrained] Speed-up loading from cache

* up

* Fix more

* fix one more bug

* make style

* bigger refactor

* factor out function

* Improve more

* better

* deprecate return cache folder

* clean up

* improve tests

* up

* upload

* add nice tests

* simplify

* finish

* correct

* fix version

* rename

* Apply suggestions from code review

Co-authored-by: Lucain <lucainp@gmail.com>

* rename

* correct doc string

* correct more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* apply code suggestions

* finish

---------

Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen 2023-03-10 11:56:10 +01:00 committed by GitHub
parent 7fe638c502
commit d761b58bfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 638 additions and 320 deletions

View File

@ -16,7 +16,7 @@
import argparse import argparse
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
if __name__ == "__main__": if __name__ == "__main__":
@ -125,7 +125,7 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
pipe = load_pipeline_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=args.checkpoint_path, checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file, original_config_file=args.original_config_file,
image_size=args.image_size, image_size=args.image_size,

View File

@ -86,7 +86,8 @@ _deps = [
"filelock", "filelock",
"flax>=0.4.1", "flax>=0.4.1",
"hf-doc-builder>=0.3.0", "hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0", "huggingface-hub>=0.13.0",
"requests-mock==1.10.0",
"importlib_metadata", "importlib_metadata",
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2", "jax>=0.2.8,!=0.3.2",
@ -192,6 +193,7 @@ extras["test"] = deps_list(
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
"requests-mock",
"safetensors", "safetensors",
"sentencepiece", "sentencepiece",
"scipy", "scipy",

View File

@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError from requests import HTTPError
from . import __version__ from . import __version__
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DummyObject,
deprecate,
extract_commit_hash,
http_user_agent,
logging,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -231,7 +239,11 @@ class ConfigMixin:
@classmethod @classmethod
def load_config( def load_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
return_unused_kwargs=False,
return_commit_hash=False,
**kwargs,
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r""" r"""
Instantiate a Python class from a config dictionary Instantiate a Python class from a config dictionary
@ -271,6 +283,10 @@ class ConfigMixin:
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here. huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False):
Whether unused keyword arguments of the config shall be returned.
return_commit_hash (`bool`, *optional*, defaults to `False):
Whether the commit_hash of the loaded configuration shall be returned.
<Tip> <Tip>
@ -295,8 +311,10 @@ class ConfigMixin:
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
user_agent = {"file_type": "config"} user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@ -336,7 +354,6 @@ class ConfigMixin:
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision,
) )
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
@ -378,13 +395,23 @@ class ConfigMixin:
try: try:
# Load config dict # Load config dict
config_dict = cls._dict_from_json_file(config_file) config_dict = cls._dict_from_json_file(config_file)
commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
if return_unused_kwargs: if not (return_unused_kwargs or return_commit_hash):
return config_dict, kwargs return config_dict
return config_dict outputs = (config_dict,)
if return_unused_kwargs:
outputs += (kwargs,)
if return_commit_hash:
outputs += (commit_hash,)
return outputs
@staticmethod @staticmethod
def _get_init_keys(cls): def _get_init_keys(cls):

View File

@ -10,7 +10,8 @@ deps = {
"filelock": "filelock", "filelock": "filelock",
"flax": "flax>=0.4.1", "flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0", "hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0", "huggingface-hub": "huggingface-hub>=0.13.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata", "importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2", "jax": "jax>=0.2.8,!=0.3.2",

View File

@ -458,18 +458,34 @@ class ModelMixin(torch.nn.Module):
" dispatching. Please make sure to set `low_cpu_mem_usage=True`." " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
) )
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
user_agent = { user_agent = {
"diffusers": __version__, "diffusers": __version__,
"file_type": "model", "file_type": "model",
"framework": "pytorch", "framework": "pytorch",
} }
# Load config if we don't provide a configuration # load config
config_path = pretrained_model_name_or_path config, unused_kwargs, commit_hash = cls.load_config(
config_path,
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the cache_dir=cache_dir,
# Load model return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
user_agent=user_agent,
**kwargs,
)
# load model
model_file = None model_file = None
if from_flax: if from_flax:
model_file = _get_model_file( model_file = _get_model_file(
@ -484,20 +500,7 @@ class ModelMixin(torch.nn.Module):
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
) commit_hash=commit_hash,
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
) )
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module):
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
commit_hash=commit_hash,
) )
except: # noqa: E722 except: # noqa: E722
pass pass
@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module):
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
commit_hash=commit_hash,
) )
if low_cpu_mem_usage: if low_cpu_mem_usage:
# Instantiate model with empty weights # Instantiate model with empty weights
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu # if device_map is None, load the state dict and move the params from meta device to the cpu
@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module):
"error_msgs": [], "error_msgs": [],
} }
else: else:
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant) state_dict = load_state_dict(model_file, variant=variant)
@ -803,6 +780,7 @@ def _get_model_file(
use_auth_token, use_auth_token,
user_agent, user_agent,
revision, revision,
commit_hash=None,
): ):
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
@ -840,7 +818,7 @@ def _get_model_file(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision or commit_hash,
) )
warnings.warn( warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
@ -865,7 +843,7 @@ def _get_model_file(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision or commit_hash,
) )
return model_file return model_file

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import fnmatch
import importlib import importlib
import inspect import inspect
import os import os
@ -26,7 +27,8 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from huggingface_hub import model_info, snapshot_download from huggingface_hub import hf_hub_download, model_info, snapshot_download
from huggingface_hub.utils import send_telemetry
from packaging import version from packaging import version
from PIL import Image from PIL import Image
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -47,7 +49,6 @@ from ..utils import (
BaseOutput, BaseOutput,
deprecate, deprecate,
get_class_from_dynamic_module, get_class_from_dynamic_module,
http_user_agent,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_safetensors_available, is_safetensors_available,
@ -179,8 +180,7 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
return True return True
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]: def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
filenames = set(sibling.rfilename for sibling in info.siblings)
weight_names = [ weight_names = [
WEIGHTS_NAME, WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
@ -220,6 +220,177 @@ def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike],
return usable_filenames, variant_filenames return usable_filenames, variant_filenames
def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames):
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=None,
)
filenames = set(sibling.rfilename for sibling in info.siblings)
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
if set(comp_model_filenames) == set(model_filenames):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
return class_obj, class_candidates
def load_sub_model(
library_name: str,
class_name: str,
importable_classes: List[Any],
pipelines: Any,
is_pipeline_module: bool,
pipeline_class: Any,
torch_dtype: torch.dtype,
provider: Any,
sess_options: Any,
device_map: Optional[Union[Dict[str, torch.device], str]],
model_variants: Dict[str, str],
name: str,
from_flax: bool,
variant: str,
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module
)
load_method_name = None
# retrive load method name
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
# if load method name is None, then we have a dummy module -> raise Error
if load_method_name is None:
none_module = class_obj.__module__
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
TRANSFORMERS_DUMMY_MODULES_FOLDER
)
if is_dummy_path and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
load_method = getattr(class_obj, load_method_name)
# add kwargs to loading method
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["variant"] = model_variants.pop(name, None)
if from_flax:
loading_kwargs["from_flax"] = True
# the following can be deleted once the minimum required `transformers` version
# is higher than 4.27
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
loading_kwargs.pop("variant")
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
if not (from_flax and is_transformers_model):
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
else:
loading_kwargs["low_cpu_mem_usage"] = False
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
return loaded_sub_model
class DiffusionPipeline(ConfigMixin): class DiffusionPipeline(ConfigMixin):
r""" r"""
Base class for all models. Base class for all models.
@ -524,8 +695,6 @@ class DiffusionPipeline(ConfigMixin):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error. setting this argument to `True` will raise an error.
return_cached_folder (`bool`, *optional*, defaults to `False`):
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines specific pipeline class. The overwritten components are then directly passed to the pipelines
@ -583,13 +752,12 @@ class DiffusionPipeline(ConfigMixin):
sess_options = kwargs.pop("sess_options", None) sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path): if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.load_config( cached_folder = cls.download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
resume_download=resume_download, resume_download=resume_download,
@ -598,117 +766,18 @@ class DiffusionPipeline(ConfigMixin):
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
) from_flax=from_flax,
custom_pipeline=custom_pipeline,
# retrieve all folder_names that contain relevant files variant=variant,
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
if not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.17.0"):
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=None,
)
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
comp_model_filenames = [
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
]
if set(comp_model_filenames) == set(model_filenames):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.jsons with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
# allow everything since it has to be downloaded anyways
ignore_patterns = allow_patterns = None
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
) )
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
config_dict = cls.load_config(cached_folder)
# retrieve which subfolders should load variants config_dict = cls.load_config(cached_folder)
# 2. Define which model components should load variants
# We retrieve the information by matching whether variant
# model checkpoints exist in the subfolders
model_variants = {} model_variants = {}
if variant is not None: if variant is not None:
for folder in os.listdir(cached_folder): for folder in os.listdir(cached_folder):
@ -718,7 +787,7 @@ class DiffusionPipeline(ConfigMixin):
if variant_exists: if variant_exists:
model_variants[folder] = variant model_variants[folder] = variant
# 2. Load the pipeline class, if using custom module then load it from the hub # 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
if custom_pipeline is not None: if custom_pipeline is not None:
if custom_pipeline.endswith(".py"): if custom_pipeline.endswith(".py"):
@ -738,7 +807,7 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# To be removed in 1.0.0 # DEPRECATED: To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config_dict["_diffusers_version"]).base_version version.parse(config_dict["_diffusers_version"]).base_version
) <= version.parse("0.5.1"): ) <= version.parse("0.5.1"):
@ -757,6 +826,9 @@ class DiffusionPipeline(ConfigMixin):
) )
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
# 4. Define expected modules given pipeline signature
# and define non-None initialized modules (=`init_kwargs`)
# some modules can be passed directly to the init # some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs` # in this case they are already instantiated in `kwargs`
# extract them here # extract them here
@ -788,6 +860,7 @@ class DiffusionPipeline(ConfigMixin):
" separately if you need it." " separately if you need it."
) )
# 5. Throw nice warnings / errors for fast accelerate loading
if len(unused_kwargs) > 0: if len(unused_kwargs) > 0:
logger.warning( logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
@ -823,135 +896,50 @@ class DiffusionPipeline(ConfigMixin):
# import it here to avoid circular import # import it here to avoid circular import
from diffusers import pipelines from diffusers import pipelines
# 3. Load each module in the pipeline # 6. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"): if class_name.startswith("Flax"):
class_name = class_name[4:] class_name = class_name[4:]
# 6.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
loaded_sub_model = None loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline # 6.3 Use passed sub model or load class_name from library_name
if name in passed_class_obj: if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class # if the model is in a pipeline module, then we load it from the pipeline
if not is_pipeline_module: # check that passed_class_obj has correct parent class
library = importlib.import_module(library_name) maybe_raise_or_warn(
class_obj = getattr(library, class_name) library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
if load_method_name is None:
none_module = class_obj.__module__
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
TRANSFORMERS_DUMMY_MODULES_FOLDER
)
if is_dummy_path and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
load_method = getattr(class_obj, load_method_name)
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
) )
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. loaded_sub_model = passed_class_obj[name]
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. else:
# This makes sure that the weights won't be initialized which significantly speeds up loading. # load sub model
if is_diffusers_model or is_transformers_model: loaded_sub_model = load_sub_model(
loading_kwargs["device_map"] = device_map library_name=library_name,
loading_kwargs["variant"] = model_variants.pop(name, None) class_name=class_name,
if from_flax: importable_classes=importable_classes,
loading_kwargs["from_flax"] = True pipelines=pipelines,
is_pipeline_module=is_pipeline_module,
# the following can be deleted once the minimum required `transformers` version pipeline_class=pipeline_class,
# is higher than 4.27 torch_dtype=torch_dtype,
if ( provider=provider,
is_transformers_model sess_options=sess_options,
and loading_kwargs["variant"] is not None device_map=device_map,
and transformers_version < version.parse("4.27.0") model_variants=model_variants,
): name=name,
raise ImportError( from_flax=from_flax,
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" variant=variant,
) low_cpu_mem_usage=low_cpu_mem_usage,
elif is_transformers_model and loading_kwargs["variant"] is None: cached_folder=cached_folder,
loading_kwargs.pop("variant") )
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
if not (from_flax and is_transformers_model):
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
else:
loading_kwargs["low_cpu_mem_usage"] = False
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 4. Potentially add passed objects if expected # 7. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys()) missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys()) passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components optional_modules = pipeline_class._optional_components
@ -964,13 +952,251 @@ class DiffusionPipeline(ConfigMixin):
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
) )
# 5. Instantiate the pipeline # 8. Instantiate the pipeline
model = pipeline_class(**init_kwargs) model = pipeline_class(**init_kwargs)
return_cached_folder = kwargs.pop("return_cached_folder", False)
if return_cached_folder: if return_cached_folder:
message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
return model, cached_folder return model, cached_folder
return model return model
@classmethod
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
r"""
Download and cache a PyTorch diffusion pipeline from pre-trained pipeline weights.
Parameters:
pretrained_model_name (`str` or `os.PathLike`, *optional*):
Should be a string, the *repo id* of a pretrained pipeline hosted inside a model repo on
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
`CompVis/ldm-text2im-large-256`.
custom_pipeline (`str`, *optional*):
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
Can be either:
- A string, the *repo id* of a custom pipeline hosted inside a model repo on
https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
like `hf-internal-testing/diffusers-dummy-pipeline`.
<Tip>
It is required that the model repo has a file, called `pipeline.py` that defines the custom
pipeline.
</Tip>
- A string, the *file name* of a community pipeline hosted on GitHub under
https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
match exactly the file name without `.py` located under the above link, *e.g.*
`clip_guided_stable_diffusion`.
<Tip>
Community pipelines are always loaded from the current `main` branch of GitHub.
</Tip>
- A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
<Tip>
It is required that the directory has a file, called `pipeline.py` that defines the custom
pipeline.
</Tip>
For more information on how to load and create custom pipelines, please have a look at [Loading and
Adding Custom
Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of
`diffusers` when loading from GitHub):
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
`revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a
custom pipeline from GitHub.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. specify the folder name here.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
</Tip>
<Tip>
Activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#notice-on-telemetry-logging) to use this
method in a firewalled environment.
</Tip>
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
variant = kwargs.pop("variant", None)
pipeline_is_cached = False
allow_patterns = None
ignore_patterns = None
user_agent = {"pipeline_class": cls.__name__}
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
user_agent["custom_pipeline"] = custom_pipeline
if not local_files_only:
info = model_info(
pretrained_model_name,
use_auth_token=use_auth_token,
revision=revision,
)
user_agent["pretrained_model_name"] = pretrained_model_name
send_telemetry("pipelines", library_name="diffusers", library_version=__version__, user_agent=user_agent)
commit_hash = info.sha
# try loading the config file
config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
cache_dir=cache_dir,
revision=commit_hash,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
use_auth_token=use_auth_token,
)
config_dict = cls._dict_from_json_file(config_file)
config_is_cached = True
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
filenames = set(sibling.rfilename for sibling in info.siblings)
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
# if the whole pipeline is cached we don't have to ping the Hub
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.17.0"):
warn_deprecated_model_variant(
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
if config_is_cached:
re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]
snapshot_folder = Path(config_file).parent
pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)
if pipeline_is_cached:
# if the pipeline is cached, we can directly return it
# else call snapshot_download
return snapshot_folder
# download all allow_patterns - ignore_patterns
cached_folder = snapshot_download(
pretrained_model_name,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
return cached_folder
@staticmethod @staticmethod
def _get_signature_keys(obj): def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters parameters = inspect.signature(obj.__init__).parameters

View File

@ -954,7 +954,7 @@ def stable_unclip_image_noising_components(
return image_normalizer, image_noising_scheduler return image_normalizer, image_noising_scheduler
def load_pipeline_from_original_stable_diffusion_ckpt( def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path: str,
original_config_file: str = None, original_config_file: str = None,
image_size: int = 512, image_size: int = 512,

View File

@ -136,10 +136,11 @@ class SchedulerMixin:
</Tip> </Tip>
""" """
config, kwargs = cls.load_config( config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path, pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder, subfolder=subfolder,
return_unused_kwargs=True, return_unused_kwargs=True,
return_commit_hash=True,
**kwargs, **kwargs,
) )
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

View File

@ -35,7 +35,11 @@ from .constants import (
from .deprecation_utils import deprecate from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import HF_HUB_OFFLINE, http_user_agent from .hub_utils import (
HF_HUB_OFFLINE,
extract_commit_hash,
http_user_agent,
)
from .import_utils import ( from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES, ENV_VARS_TRUE_VALUES,

View File

@ -15,6 +15,7 @@
import os import os
import re
import sys import sys
import traceback import traceback
from pathlib import Path from pathlib import Path
@ -22,6 +23,7 @@ from typing import Dict, Optional, Union
from uuid import uuid4 from uuid import uuid4
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import is_jinja_available from huggingface_hub.utils import is_jinja_available
from .. import __version__ from .. import __version__
@ -132,6 +134,20 @@ def create_model_card(args, model_name):
model_card.save(card_path) model_card.save(card_path)
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
"""
Extracts the commit hash from a resolved filename toward a cache file.
"""
if resolved_file is None or commit_hash is not None:
return commit_hash
resolved_file = str(Path(resolved_file).as_posix())
search = re.search(r"snapshots/([^/]+)/", resolved_file)
if search is None:
return None
commit_hash = search.groups()[0]
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
# Old default cache path, potentially to be migrated. # Old default cache path, potentially to be migrated.
# This logic was more or less taken from `transformers`, with the following differences: # This logic was more or less taken from `transformers`, with the following differences:
# - Diffusers doesn't use custom environment variables to specify the cache path. # - Diffusers doesn't use custom environment variables to specify the cache path.
@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
old_cache_dir = Path(old_cache_dir).expanduser() old_cache_dir = Path(old_cache_dir).expanduser()
new_cache_dir = Path(new_cache_dir).expanduser() new_cache_dir = Path(new_cache_dir).expanduser()
for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob for old_blob_path in old_cache_dir.glob("**/blobs/*"):
if old_blob_path.is_file() and not old_blob_path.is_symlink(): if old_blob_path.is_file() and not old_blob_path.is_symlink():
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
new_blob_path.parent.mkdir(parents=True, exist_ok=True) new_blob_path.parent.mkdir(parents=True, exist_ok=True)

View File

@ -20,6 +20,7 @@ import unittest.mock as mock
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
import requests_mock
import torch import torch
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
@ -29,6 +30,13 @@ from diffusers.utils import torch_device
class ModelUtilsTest(unittest.TestCase): class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
import diffusers
diffusers.utils.import_utils._safetensors_available = True
def test_accelerate_loading_error_message(self): def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context: with self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase):
if p1.data.ne(p2.data).sum() > 0: if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!" assert False, "Parameters not the same!"
def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
if torch_device == "mps":
return
import diffusers
diffusers.utils.import_utils._safetensors_available = False
with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m:
UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
)
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
)
cache_requests = [r.method for r in m.request_history]
assert (
"HEAD" == cache_requests[0] and len(cache_requests) == 1
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
diffusers.utils.import_utils._safetensors_available = True
class ModelTesterMixin: class ModelTesterMixin:
def test_from_save_pretrained(self): def test_from_save_pretrained(self):

View File

@ -25,6 +25,7 @@ import unittest.mock as mock
import numpy as np import numpy as np
import PIL import PIL
import requests_mock
import safetensors.torch import safetensors.torch
import torch import torch
from parameterized import parameterized from parameterized import parameterized
@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False
class DownloadTests(unittest.TestCase): class DownloadTests(unittest.TestCase):
def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
if torch_device == "mps":
return
with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry"
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
assert (
len(download_requests) == 33
), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
assert (
len(cache_requests) == 2
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
def test_download_only_pytorch(self): def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights # pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained( tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
) )
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a flax file even if we have some here: # None of the downloaded files should be a flax file even if we have some here:
@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase):
def test_download_safetensors(self): def test_download_safetensors(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights # pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained( tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors", "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
safety_checker=None, safety_checker=None,
cache_dir=tmpdirname, cache_dir=tmpdirname,
) )
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a pytorch file even if we have some here: # None of the downloaded files should be a pytorch file even if we have some here:
@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase):
other_format = ".bin" if safe_avail else ".safetensors" other_format = ".bin" if safe_avail else ".safetensors"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
) )
all_root_files = [ all_root_files = [t[-1] for t in os.walk(tmpdirname)]
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a variant file even if we have some here: # None of the downloaded files should be a variant file even if we have some here:
@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase):
variant = "fp16" variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
) )
all_root_files = [ all_root_files = [t[-1] for t in os.walk(tmpdirname)]
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here: # None of the downloaded files should be a non-variant file even if we have some here:
@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase):
variant = "no_ema" variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
) )
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots") all_root_files = [t[-1] for t in os.walk(tmpdirname)]
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet")) unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
# Some of the downloaded files should be a non-variant file, check: # Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase):
for variant in [None, "no_ema"]: for variant in [None, "no_ema"]:
with self.assertRaises(OSError) as error_context: with self.assertRaises(OSError) as error_context:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants", "hf-internal-testing/stable-diffusion-broken-variants",
cache_dir=tmpdirname, cache_dir=tmpdirname,
variant=variant, variant=variant,
@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase):
# text encoder has fp16 variants so we can load it # text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pipe = StableDiffusionPipeline.from_pretrained( tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16" "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
) )
assert pipe is not None
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots") all_root_files = [t[-1] for t in os.walk(tmpdirname)]
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist] files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here: # None of the downloaded files should be a non-variant file even if we have some here:
@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_gpu
def test_load_pipeline_from_git(self): def test_download_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)