[From pretrained] Speed-up loading from cache (#2515)
* [From pretrained] Speed-up loading from cache * up * Fix more * fix one more bug * make style * bigger refactor * factor out function * Improve more * better * deprecate return cache folder * clean up * improve tests * up * upload * add nice tests * simplify * finish * correct * fix version * rename * Apply suggestions from code review Co-authored-by: Lucain <lucainp@gmail.com> * rename * correct doc string * correct more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * apply code suggestions * finish --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
7fe638c502
commit
d761b58bfc
|
@ -16,7 +16,7 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -125,7 +125,7 @@ if __name__ == "__main__":
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=args.checkpoint_path,
|
||||
original_config_file=args.original_config_file,
|
||||
image_size=args.image_size,
|
||||
|
|
4
setup.py
4
setup.py
|
@ -86,7 +86,8 @@ _deps = [
|
|||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.10.0",
|
||||
"huggingface-hub>=0.13.0",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
|
@ -192,6 +193,7 @@ extras["test"] = deps_list(
|
|||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"requests-mock",
|
||||
"safetensors",
|
||||
"sentencepiece",
|
||||
"scipy",
|
||||
|
|
|
@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
|||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
DummyObject,
|
||||
deprecate,
|
||||
extract_commit_hash,
|
||||
http_user_agent,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -231,7 +239,11 @@ class ConfigMixin:
|
|||
|
||||
@classmethod
|
||||
def load_config(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
return_unused_kwargs=False,
|
||||
return_commit_hash=False,
|
||||
**kwargs,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Instantiate a Python class from a config dictionary
|
||||
|
@ -271,6 +283,10 @@ class ConfigMixin:
|
|||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
||||
huggingface.co or downloaded locally), you can specify the folder name here.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False):
|
||||
Whether unused keyword arguments of the config shall be returned.
|
||||
return_commit_hash (`bool`, *optional*, defaults to `False):
|
||||
Whether the commit_hash of the loaded configuration shall be returned.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
@ -295,8 +311,10 @@ class ConfigMixin:
|
|||
revision = kwargs.pop("revision", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
user_agent = {"file_type": "config"}
|
||||
user_agent = {**user_agent, "file_type": "config"}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
|
||||
|
@ -336,7 +354,6 @@ class ConfigMixin:
|
|||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
||||
|
@ -378,14 +395,24 @@ class ConfigMixin:
|
|||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
commit_hash = extract_commit_hash(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
||||
|
||||
if return_unused_kwargs:
|
||||
return config_dict, kwargs
|
||||
|
||||
if not (return_unused_kwargs or return_commit_hash):
|
||||
return config_dict
|
||||
|
||||
outputs = (config_dict,)
|
||||
|
||||
if return_unused_kwargs:
|
||||
outputs += (kwargs,)
|
||||
|
||||
if return_commit_hash:
|
||||
outputs += (commit_hash,)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _get_init_keys(cls):
|
||||
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
|
|
|
@ -10,7 +10,8 @@ deps = {
|
|||
"filelock": "filelock",
|
||||
"flax": "flax>=0.4.1",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.10.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.13.0",
|
||||
"requests-mock": "requests-mock==1.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
|
|
|
@ -458,18 +458,34 @@ class ModelMixin(torch.nn.Module):
|
|||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# Load model
|
||||
# load config
|
||||
config, unused_kwargs, commit_hash = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
user_agent=user_agent,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# load model
|
||||
model_file = None
|
||||
if from_flax:
|
||||
model_file = _get_model_file(
|
||||
|
@ -484,20 +500,7 @@ class ModelMixin(torch.nn.Module):
|
|||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
config, unused_kwargs = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
|
@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module):
|
|||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
except: # noqa: E722
|
||||
pass
|
||||
|
@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module):
|
|||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
config, unused_kwargs = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||
|
@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module):
|
|||
"error_msgs": [],
|
||||
}
|
||||
else:
|
||||
config, unused_kwargs = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
|
@ -803,6 +780,7 @@ def _get_model_file(
|
|||
use_auth_token,
|
||||
user_agent,
|
||||
revision,
|
||||
commit_hash=None,
|
||||
):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
|
@ -840,7 +818,7 @@ def _get_model_file(
|
|||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
revision=revision or commit_hash,
|
||||
)
|
||||
warnings.warn(
|
||||
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
|
@ -865,7 +843,7 @@ def _get_model_file(
|
|||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
revision=revision or commit_hash,
|
||||
)
|
||||
return model_file
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
@ -26,7 +27,8 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from huggingface_hub import model_info, snapshot_download
|
||||
from huggingface_hub import hf_hub_download, model_info, snapshot_download
|
||||
from huggingface_hub.utils import send_telemetry
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
@ -47,7 +49,6 @@ from ..utils import (
|
|||
BaseOutput,
|
||||
deprecate,
|
||||
get_class_from_dynamic_module,
|
||||
http_user_agent,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_safetensors_available,
|
||||
|
@ -179,8 +180,7 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
|
@ -220,6 +220,177 @@ def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike],
|
|||
return usable_filenames, variant_filenames
|
||||
|
||||
|
||||
def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames):
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=None,
|
||||
)
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
|
||||
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
|
||||
|
||||
if set(comp_model_filenames) == set(model_filenames):
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
|
||||
def maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
):
|
||||
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||
f" {expected_class_obj}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
return class_obj, class_candidates
|
||||
|
||||
|
||||
def load_sub_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
importable_classes: List[Any],
|
||||
pipelines: Any,
|
||||
is_pipeline_module: bool,
|
||||
pipeline_class: Any,
|
||||
torch_dtype: torch.dtype,
|
||||
provider: Any,
|
||||
sess_options: Any,
|
||||
device_map: Optional[Union[Dict[str, torch.device], str]],
|
||||
model_variants: Dict[str, str],
|
||||
name: str,
|
||||
from_flax: bool,
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
# retrive load method name
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
# if load method name is None, then we have a dummy module -> raise Error
|
||||
if load_method_name is None:
|
||||
none_module = class_obj.__module__
|
||||
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
||||
)
|
||||
if is_dummy_path and "dummy" in none_module:
|
||||
# call class_obj for nice error message of missing requirements
|
||||
class_obj()
|
||||
|
||||
raise ValueError(
|
||||
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
||||
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
# add kwargs to loading method
|
||||
loading_kwargs = {}
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
# the following can be deleted once the minimum required `transformers` version
|
||||
# is higher than 4.27
|
||||
if (
|
||||
is_transformers_model
|
||||
and loading_kwargs["variant"] is not None
|
||||
and transformers_version < version.parse("4.27.0")
|
||||
):
|
||||
raise ImportError(
|
||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
||||
)
|
||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
||||
loading_kwargs.pop("variant")
|
||||
|
||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||
if not (from_flax and is_transformers_model):
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||
else:
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
@ -524,8 +695,6 @@ class DiffusionPipeline(ConfigMixin):
|
|||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
return_cached_folder (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||
|
@ -583,13 +752,12 @@ class DiffusionPipeline(ConfigMixin):
|
|||
sess_options = kwargs.pop("sess_options", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||
variant = kwargs.pop("variant", None)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
config_dict = cls.load_config(
|
||||
cached_folder = cls.download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
|
@ -598,117 +766,18 @@ class DiffusionPipeline(ConfigMixin):
|
|||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
|
||||
if not local_files_only:
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
|
||||
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
|
||||
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.17.0"):
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=None,
|
||||
)
|
||||
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
|
||||
comp_model_filenames = [
|
||||
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
|
||||
]
|
||||
|
||||
if set(comp_model_filenames) == set(model_filenames):
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
|
||||
# also allow downloading config.jsons with the model
|
||||
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
|
||||
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
|
||||
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
|
||||
if (
|
||||
len(safetensors_variant_filenames) > 0
|
||||
and safetensors_model_filenames != safetensors_variant_filenames
|
||||
):
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
|
||||
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
# allow everything since it has to be downloaded anyways
|
||||
ignore_patterns = allow_patterns = None
|
||||
|
||||
if cls != DiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
|
||||
user_agent["custom_pipeline"] = custom_pipeline
|
||||
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
from_flax=from_flax,
|
||||
custom_pipeline=custom_pipeline,
|
||||
variant=variant,
|
||||
)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
# retrieve which subfolders should load variants
|
||||
# 2. Define which model components should load variants
|
||||
# We retrieve the information by matching whether variant
|
||||
# model checkpoints exist in the subfolders
|
||||
model_variants = {}
|
||||
if variant is not None:
|
||||
for folder in os.listdir(cached_folder):
|
||||
|
@ -718,7 +787,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
if variant_exists:
|
||||
model_variants[folder] = variant
|
||||
|
||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
|
@ -738,7 +807,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
|
||||
# To be removed in 1.0.0
|
||||
# DEPRECATED: To be removed in 1.0.0
|
||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
||||
version.parse(config_dict["_diffusers_version"]).base_version
|
||||
) <= version.parse("0.5.1"):
|
||||
|
@ -757,6 +826,9 @@ class DiffusionPipeline(ConfigMixin):
|
|||
)
|
||||
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
# 4. Define expected modules given pipeline signature
|
||||
# and define non-None initialized modules (=`init_kwargs`)
|
||||
|
||||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
# extract them here
|
||||
|
@ -788,6 +860,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
" separately if you need it."
|
||||
)
|
||||
|
||||
# 5. Throw nice warnings / errors for fast accelerate loading
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||
|
@ -823,135 +896,50 @@ class DiffusionPipeline(ConfigMixin):
|
|||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
# 6. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
|
||||
# 6.2 Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
|
||||
loaded_sub_model = None
|
||||
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# 6.3 Use passed sub model or load class_name from library_name
|
||||
if name in passed_class_obj:
|
||||
# 1. check that passed_class_obj has correct parent class
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||
f" {expected_class_obj}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
)
|
||||
|
||||
# set passed class object
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
elif is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
if loaded_sub_model is None:
|
||||
load_method_name = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
if load_method_name is None:
|
||||
none_module = class_obj.__module__
|
||||
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
||||
# load sub model
|
||||
loaded_sub_model = load_sub_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
importable_classes=importable_classes,
|
||||
pipelines=pipelines,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
pipeline_class=pipeline_class,
|
||||
torch_dtype=torch_dtype,
|
||||
provider=provider,
|
||||
sess_options=sess_options,
|
||||
device_map=device_map,
|
||||
model_variants=model_variants,
|
||||
name=name,
|
||||
from_flax=from_flax,
|
||||
variant=variant,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
)
|
||||
if is_dummy_path and "dummy" in none_module:
|
||||
# call class_obj for nice error message of missing requirements
|
||||
class_obj()
|
||||
|
||||
raise ValueError(
|
||||
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
||||
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
loading_kwargs = {}
|
||||
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
# the following can be deleted once the minimum required `transformers` version
|
||||
# is higher than 4.27
|
||||
if (
|
||||
is_transformers_model
|
||||
and loading_kwargs["variant"] is not None
|
||||
and transformers_version < version.parse("4.27.0")
|
||||
):
|
||||
raise ImportError(
|
||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
||||
)
|
||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
||||
loading_kwargs.pop("variant")
|
||||
|
||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||
if not (from_flax and is_transformers_model):
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||
else:
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
# 4. Potentially add passed objects if expected
|
||||
# 7. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
|
@ -964,13 +952,251 @@ class DiffusionPipeline(ConfigMixin):
|
|||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# 5. Instantiate the pipeline
|
||||
# 8. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
|
||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||
if return_cached_folder:
|
||||
message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
|
||||
deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
|
||||
return model, cached_folder
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
||||
r"""
|
||||
Download and cache a PyTorch diffusion pipeline from pre-trained pipeline weights.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name (`str` or `os.PathLike`, *optional*):
|
||||
Should be a string, the *repo id* of a pretrained pipeline hosted inside a model repo on
|
||||
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
|
||||
`CompVis/ldm-text2im-large-256`.
|
||||
custom_pipeline (`str`, *optional*):
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Can be either:
|
||||
|
||||
- A string, the *repo id* of a custom pipeline hosted inside a model repo on
|
||||
https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
|
||||
like `hf-internal-testing/diffusers-dummy-pipeline`.
|
||||
|
||||
<Tip>
|
||||
|
||||
It is required that the model repo has a file, called `pipeline.py` that defines the custom
|
||||
pipeline.
|
||||
|
||||
</Tip>
|
||||
|
||||
- A string, the *file name* of a community pipeline hosted on GitHub under
|
||||
https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
|
||||
match exactly the file name without `.py` located under the above link, *e.g.*
|
||||
`clip_guided_stable_diffusion`.
|
||||
|
||||
<Tip>
|
||||
|
||||
Community pipelines are always loaded from the current `main` branch of GitHub.
|
||||
|
||||
</Tip>
|
||||
|
||||
- A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
|
||||
|
||||
<Tip>
|
||||
|
||||
It is required that the directory has a file, called `pipeline.py` that defines the custom
|
||||
pipeline.
|
||||
|
||||
</Tip>
|
||||
|
||||
For more information on how to load and create custom pipelines, please have a look at [Loading and
|
||||
Adding Custom
|
||||
Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
|
||||
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of
|
||||
`diffusers` when loading from GitHub):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
||||
`revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a
|
||||
custom pipeline from GitHub.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information. specify the folder name here.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
|
||||
<Tip>
|
||||
|
||||
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
||||
models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Activate the special
|
||||
["offline-mode"](https://huggingface.co/diffusers/installation.html#notice-on-telemetry-logging) to use this
|
||||
method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
|
||||
pipeline_is_cached = False
|
||||
allow_patterns = None
|
||||
ignore_patterns = None
|
||||
|
||||
user_agent = {"pipeline_class": cls.__name__}
|
||||
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
|
||||
user_agent["custom_pipeline"] = custom_pipeline
|
||||
|
||||
if not local_files_only:
|
||||
info = model_info(
|
||||
pretrained_model_name,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
user_agent["pretrained_model_name"] = pretrained_model_name
|
||||
send_telemetry("pipelines", library_name="diffusers", library_version=__version__, user_agent=user_agent)
|
||||
commit_hash = info.sha
|
||||
|
||||
# try loading the config file
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name,
|
||||
cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=commit_hash,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
config_is_cached = True
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
# if the whole pipeline is cached we don't have to ping the Hub
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.17.0"):
|
||||
warn_deprecated_model_variant(
|
||||
pretrained_model_name, use_auth_token, variant, revision, model_filenames
|
||||
)
|
||||
|
||||
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
|
||||
|
||||
allow_patterns += [
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
|
||||
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
|
||||
if (
|
||||
len(safetensors_variant_filenames) > 0
|
||||
and safetensors_model_filenames != safetensors_variant_filenames
|
||||
):
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
|
||||
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
|
||||
if config_is_cached:
|
||||
re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
|
||||
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
|
||||
|
||||
expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
|
||||
expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]
|
||||
|
||||
snapshot_folder = Path(config_file).parent
|
||||
pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)
|
||||
|
||||
if pipeline_is_cached:
|
||||
# if the pipeline is cached, we can directly return it
|
||||
# else call snapshot_download
|
||||
return snapshot_folder
|
||||
|
||||
# download all allow_patterns - ignore_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
return cached_folder
|
||||
|
||||
@staticmethod
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
|
|
|
@ -954,7 +954,7 @@ def stable_unclip_image_noising_components(
|
|||
return image_normalizer, image_noising_scheduler
|
||||
|
||||
|
||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
original_config_file: str = None,
|
||||
image_size: int = 512,
|
||||
|
|
|
@ -136,10 +136,11 @@ class SchedulerMixin:
|
|||
</Tip>
|
||||
|
||||
"""
|
||||
config, kwargs = cls.load_config(
|
||||
config, kwargs, commit_hash = cls.load_config(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
**kwargs,
|
||||
)
|
||||
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
||||
|
|
|
@ -35,7 +35,11 @@ from .constants import (
|
|||
from .deprecation_utils import deprecate
|
||||
from .doc_utils import replace_example_docstring
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
|
||||
from .hub_utils import (
|
||||
HF_HUB_OFFLINE,
|
||||
extract_commit_hash,
|
||||
http_user_agent,
|
||||
)
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
@ -22,6 +23,7 @@ from typing import Dict, Optional, Union
|
|||
from uuid import uuid4
|
||||
|
||||
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
|
||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
|
||||
from .. import __version__
|
||||
|
@ -132,6 +134,20 @@ def create_model_card(args, model_name):
|
|||
model_card.save(card_path)
|
||||
|
||||
|
||||
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
|
||||
"""
|
||||
Extracts the commit hash from a resolved filename toward a cache file.
|
||||
"""
|
||||
if resolved_file is None or commit_hash is not None:
|
||||
return commit_hash
|
||||
resolved_file = str(Path(resolved_file).as_posix())
|
||||
search = re.search(r"snapshots/([^/]+)/", resolved_file)
|
||||
if search is None:
|
||||
return None
|
||||
commit_hash = search.groups()[0]
|
||||
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
|
||||
|
||||
|
||||
# Old default cache path, potentially to be migrated.
|
||||
# This logic was more or less taken from `transformers`, with the following differences:
|
||||
# - Diffusers doesn't use custom environment variables to specify the cache path.
|
||||
|
@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
|
|||
|
||||
old_cache_dir = Path(old_cache_dir).expanduser()
|
||||
new_cache_dir = Path(new_cache_dir).expanduser()
|
||||
for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob
|
||||
for old_blob_path in old_cache_dir.glob("**/blobs/*"):
|
||||
if old_blob_path.is_file() and not old_blob_path.is_symlink():
|
||||
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
|
||||
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
@ -20,6 +20,7 @@ import unittest.mock as mock
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests_mock
|
||||
import torch
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
|
@ -29,6 +30,13 @@ from diffusers.utils import torch_device
|
|||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_accelerate_loading_error_message(self):
|
||||
with self.assertRaises(ValueError) as error_context:
|
||||
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
|
||||
|
@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase):
|
|||
if p1.data.ne(p2.data).sum() > 0:
|
||||
assert False, "Parameters not the same!"
|
||||
|
||||
def test_one_request_upon_cached(self):
|
||||
# TODO: For some reason this test fails on MPS where no HEAD call is made.
|
||||
if torch_device == "mps":
|
||||
return
|
||||
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = False
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
|
||||
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert (
|
||||
"HEAD" == cache_requests[0] and len(cache_requests) == 1
|
||||
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
def test_from_save_pretrained(self):
|
||||
|
|
|
@ -25,6 +25,7 @@ import unittest.mock as mock
|
|||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import requests_mock
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False
|
|||
|
||||
|
||||
class DownloadTests(unittest.TestCase):
|
||||
def test_download_only_pytorch(self):
|
||||
def test_one_request_upon_cached(self):
|
||||
# TODO: For some reason this test fails on MPS where no HEAD call is made.
|
||||
if torch_device == "mps":
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry"
|
||||
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
|
||||
assert (
|
||||
len(download_requests) == 33
|
||||
), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD"
|
||||
assert cache_requests.count("GET") == 1, "model info is only GET"
|
||||
assert (
|
||||
len(cache_requests) == 2
|
||||
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||
|
||||
def test_download_only_pytorch(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a flax file even if we have some here:
|
||||
|
@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase):
|
|||
def test_download_safetensors(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
|
||||
safety_checker=None,
|
||||
cache_dir=tmpdirname,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a pytorch file even if we have some here:
|
||||
|
@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase):
|
|||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
|
||||
)
|
||||
all_root_files = [
|
||||
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
|
||||
]
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a variant file even if we have some here:
|
||||
|
@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase):
|
|||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||
)
|
||||
all_root_files = [
|
||||
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
|
||||
]
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a non-variant file even if we have some here:
|
||||
|
@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase):
|
|||
variant = "no_ema"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||
)
|
||||
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
|
||||
all_root_files = [t[-1] for t in os.walk(snapshots)]
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))
|
||||
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
|
||||
|
||||
# Some of the downloaded files should be a non-variant file, check:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||
|
@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase):
|
|||
for variant in [None, "no_ema"]:
|
||||
with self.assertRaises(OSError) as error_context:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
|
@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase):
|
|||
|
||||
# text encoder has fp16 variants so we can load it
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname = StableDiffusionPipeline.download(
|
||||
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
|
||||
)
|
||||
assert pipe is not None
|
||||
|
||||
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
|
||||
all_root_files = [t[-1] for t in os.walk(snapshots)]
|
||||
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a non-variant file even if we have some here:
|
||||
|
@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_load_pipeline_from_git(self):
|
||||
def test_download_from_git(self):
|
||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
||||
|
|
Loading…
Reference in New Issue