[From pretrained] Speed-up loading from cache (#2515)
* [From pretrained] Speed-up loading from cache * up * Fix more * fix one more bug * make style * bigger refactor * factor out function * Improve more * better * deprecate return cache folder * clean up * improve tests * up * upload * add nice tests * simplify * finish * correct * fix version * rename * Apply suggestions from code review Co-authored-by: Lucain <lucainp@gmail.com> * rename * correct doc string * correct more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * apply code suggestions * finish --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
7fe638c502
commit
d761b58bfc
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -125,7 +125,7 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
pipe = download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path=args.checkpoint_path,
|
checkpoint_path=args.checkpoint_path,
|
||||||
original_config_file=args.original_config_file,
|
original_config_file=args.original_config_file,
|
||||||
image_size=args.image_size,
|
image_size=args.image_size,
|
||||||
|
|
4
setup.py
4
setup.py
|
@ -86,7 +86,8 @@ _deps = [
|
||||||
"filelock",
|
"filelock",
|
||||||
"flax>=0.4.1",
|
"flax>=0.4.1",
|
||||||
"hf-doc-builder>=0.3.0",
|
"hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub>=0.10.0",
|
"huggingface-hub>=0.13.0",
|
||||||
|
"requests-mock==1.10.0",
|
||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
"jax>=0.2.8,!=0.3.2",
|
"jax>=0.2.8,!=0.3.2",
|
||||||
|
@ -192,6 +193,7 @@ extras["test"] = deps_list(
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-timeout",
|
"pytest-timeout",
|
||||||
"pytest-xdist",
|
"pytest-xdist",
|
||||||
|
"requests-mock",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
"scipy",
|
"scipy",
|
||||||
|
|
|
@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
|
from .utils import (
|
||||||
|
DIFFUSERS_CACHE,
|
||||||
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
|
DummyObject,
|
||||||
|
deprecate,
|
||||||
|
extract_commit_hash,
|
||||||
|
http_user_agent,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -231,7 +239,11 @@ class ConfigMixin:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(
|
def load_config(
|
||||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
cls,
|
||||||
|
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||||
|
return_unused_kwargs=False,
|
||||||
|
return_commit_hash=False,
|
||||||
|
**kwargs,
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
r"""
|
r"""
|
||||||
Instantiate a Python class from a config dictionary
|
Instantiate a Python class from a config dictionary
|
||||||
|
@ -271,6 +283,10 @@ class ConfigMixin:
|
||||||
subfolder (`str`, *optional*, defaults to `""`):
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
||||||
huggingface.co or downloaded locally), you can specify the folder name here.
|
huggingface.co or downloaded locally), you can specify the folder name here.
|
||||||
|
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether unused keyword arguments of the config shall be returned.
|
||||||
|
return_commit_hash (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the commit_hash of the loaded configuration shall be returned.
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
|
@ -295,8 +311,10 @@ class ConfigMixin:
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
_ = kwargs.pop("mirror", None)
|
_ = kwargs.pop("mirror", None)
|
||||||
subfolder = kwargs.pop("subfolder", None)
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
|
user_agent = kwargs.pop("user_agent", {})
|
||||||
|
|
||||||
user_agent = {"file_type": "config"}
|
user_agent = {**user_agent, "file_type": "config"}
|
||||||
|
user_agent = http_user_agent(user_agent)
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
|
||||||
|
@ -336,7 +354,6 @@ class ConfigMixin:
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
except RepositoryNotFoundError:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
||||||
|
@ -378,13 +395,23 @@ class ConfigMixin:
|
||||||
try:
|
try:
|
||||||
# Load config dict
|
# Load config dict
|
||||||
config_dict = cls._dict_from_json_file(config_file)
|
config_dict = cls._dict_from_json_file(config_file)
|
||||||
|
|
||||||
|
commit_hash = extract_commit_hash(config_file)
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
||||||
|
|
||||||
if return_unused_kwargs:
|
if not (return_unused_kwargs or return_commit_hash):
|
||||||
return config_dict, kwargs
|
return config_dict
|
||||||
|
|
||||||
return config_dict
|
outputs = (config_dict,)
|
||||||
|
|
||||||
|
if return_unused_kwargs:
|
||||||
|
outputs += (kwargs,)
|
||||||
|
|
||||||
|
if return_commit_hash:
|
||||||
|
outputs += (commit_hash,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_init_keys(cls):
|
def _get_init_keys(cls):
|
||||||
|
|
|
@ -10,7 +10,8 @@ deps = {
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
"flax": "flax>=0.4.1",
|
"flax": "flax>=0.4.1",
|
||||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub": "huggingface-hub>=0.10.0",
|
"huggingface-hub": "huggingface-hub>=0.13.0",
|
||||||
|
"requests-mock": "requests-mock==1.10.0",
|
||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
"jax": "jax>=0.2.8,!=0.3.2",
|
"jax": "jax>=0.2.8,!=0.3.2",
|
||||||
|
|
|
@ -458,18 +458,34 @@ class ModelMixin(torch.nn.Module):
|
||||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load config if we don't provide a configuration
|
||||||
|
config_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
user_agent = {
|
user_agent = {
|
||||||
"diffusers": __version__,
|
"diffusers": __version__,
|
||||||
"file_type": "model",
|
"file_type": "model",
|
||||||
"framework": "pytorch",
|
"framework": "pytorch",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load config if we don't provide a configuration
|
# load config
|
||||||
config_path = pretrained_model_name_or_path
|
config, unused_kwargs, commit_hash = cls.load_config(
|
||||||
|
config_path,
|
||||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
cache_dir=cache_dir,
|
||||||
# Load model
|
return_unused_kwargs=True,
|
||||||
|
return_commit_hash=True,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
device_map=device_map,
|
||||||
|
user_agent=user_agent,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load model
|
||||||
model_file = None
|
model_file = None
|
||||||
if from_flax:
|
if from_flax:
|
||||||
model_file = _get_model_file(
|
model_file = _get_model_file(
|
||||||
|
@ -484,20 +500,7 @@ class ModelMixin(torch.nn.Module):
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
commit_hash=commit_hash,
|
||||||
config, unused_kwargs = cls.load_config(
|
|
||||||
config_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
return_unused_kwargs=True,
|
|
||||||
force_download=force_download,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
subfolder=subfolder,
|
|
||||||
device_map=device_map,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
|
@ -520,6 +523,7 @@ class ModelMixin(torch.nn.Module):
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
except: # noqa: E722
|
except: # noqa: E722
|
||||||
pass
|
pass
|
||||||
|
@ -536,25 +540,12 @@ class ModelMixin(torch.nn.Module):
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
# Instantiate model with empty weights
|
# Instantiate model with empty weights
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
config, unused_kwargs = cls.load_config(
|
|
||||||
config_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
return_unused_kwargs=True,
|
|
||||||
force_download=force_download,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
subfolder=subfolder,
|
|
||||||
device_map=device_map,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||||
|
@ -593,20 +584,6 @@ class ModelMixin(torch.nn.Module):
|
||||||
"error_msgs": [],
|
"error_msgs": [],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config, unused_kwargs = cls.load_config(
|
|
||||||
config_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
return_unused_kwargs=True,
|
|
||||||
force_download=force_download,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
subfolder=subfolder,
|
|
||||||
device_map=device_map,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
state_dict = load_state_dict(model_file, variant=variant)
|
state_dict = load_state_dict(model_file, variant=variant)
|
||||||
|
@ -803,6 +780,7 @@ def _get_model_file(
|
||||||
use_auth_token,
|
use_auth_token,
|
||||||
user_agent,
|
user_agent,
|
||||||
revision,
|
revision,
|
||||||
|
commit_hash=None,
|
||||||
):
|
):
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isfile(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
|
@ -840,7 +818,7 @@ def _get_model_file(
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
revision=revision,
|
revision=revision or commit_hash,
|
||||||
)
|
)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||||
|
@ -865,7 +843,7 @@ def _get_model_file(
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
revision=revision,
|
revision=revision or commit_hash,
|
||||||
)
|
)
|
||||||
return model_file
|
return model_file
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
@ -26,7 +27,8 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import model_info, snapshot_download
|
from huggingface_hub import hf_hub_download, model_info, snapshot_download
|
||||||
|
from huggingface_hub.utils import send_telemetry
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
@ -47,7 +49,6 @@ from ..utils import (
|
||||||
BaseOutput,
|
BaseOutput,
|
||||||
deprecate,
|
deprecate,
|
||||||
get_class_from_dynamic_module,
|
get_class_from_dynamic_module,
|
||||||
http_user_agent,
|
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_accelerate_version,
|
is_accelerate_version,
|
||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
|
@ -179,8 +180,7 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
|
||||||
weight_names = [
|
weight_names = [
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
SAFETENSORS_WEIGHTS_NAME,
|
SAFETENSORS_WEIGHTS_NAME,
|
||||||
|
@ -220,6 +220,177 @@ def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike],
|
||||||
return usable_filenames, variant_filenames
|
return usable_filenames, variant_filenames
|
||||||
|
|
||||||
|
|
||||||
|
def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames):
    """Emit a `FutureWarning` about selecting a model variant through `revision`.

    The repository's file listing is fetched with `revision=None` and filtered down to the files
    compatible with `variant=revision`. After the second dot-separated component of each of those
    names is removed, the result is compared with `model_filenames`; which of the two warning
    messages is emitted depends on whether both sets are equal.

    Args:
        pretrained_model_name_or_path: Repository identifier handed to `model_info`.
        use_auth_token: Forwarded unchanged to `model_info`.
        variant: Not read by this function.
        revision: The deprecated revision value; it is used as the variant name for the lookup
            and interpolated into the warning text.
        model_filenames: Filenames the caller resolved for the requested revision.
    """
    repo_info = model_info(
        pretrained_model_name_or_path,
        use_auth_token=use_auth_token,
        revision=None,
    )
    repo_filenames = {sibling.rfilename for sibling in repo_info.siblings}
    variant_filenames, _ = variant_compatible_siblings(repo_filenames, variant=revision)

    # Remove the second dot-separated component of every name so that the result can be
    # compared with `model_filenames`.
    normalized_filenames = set()
    for filename in variant_filenames:
        parts = filename.split(".")
        normalized_filenames.add(".".join(parts[:1] + parts[2:]))

    if normalized_filenames == set(model_filenames):
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead."
    else:
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added."

    warnings.warn(message, FutureWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """Check the type of a user-supplied component, raising or warning as appropriate.

    For a pipeline module only a warning is logged, because the type cannot be verified. Otherwise
    `library_name` is imported, `class_name` is looked up in it, and the passed object
    `passed_class_obj[name]` must be an instance of the last entry of `importable_classes` that
    `class_name` is a subclass of.

    Args:
        library_name: Name of the module to import when `is_pipeline_module` is false.
        library: Ignored on input; the name is rebound to the imported module.
        class_name: Attribute name of the expected class inside the imported module.
        importable_classes: Mapping whose keys are candidate base-class names.
        passed_class_obj: Mapping from component name to the object supplied by the caller.
        name: Key of the component to check in `passed_class_obj`.
        is_pipeline_module: Whether the component comes from a pipeline module.

    Raises:
        ValueError: If the passed object's class is not a subclass of the expected class.
    """
    if is_pipeline_module:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)

    # There is deliberately no `break`: when several candidates match, the last one wins.
    expected_class_obj = None
    for candidate_name in importable_classes.keys():
        candidate = getattr(library, candidate_name, None)
        if candidate is not None and issubclass(class_obj, candidate):
            expected_class_obj = candidate

    # NOTE(review): if no candidate matched, `expected_class_obj` is still None here and
    # `issubclass` raises a TypeError -- confirm whether callers can reach that state.
    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
        raise ValueError(
            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
            f" {expected_class_obj}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
    """Resolve a component class together with its candidate parent classes.

    Args:
        library_name: Attribute name on `pipelines` (pipeline module) or an importable module name.
        class_name: Name of the class to fetch from the resolved module.
        importable_classes: Mapping whose keys name the candidate parent classes.
        pipelines: Object holding the pipeline modules as attributes; only read when
            `is_pipeline_module` is true.
        is_pipeline_module: Selects between attribute lookup on `pipelines` and a regular import.

    Returns:
        A tuple `(class_obj, class_candidates)`. For a pipeline module every key of
        `importable_classes` maps to `class_obj` itself; otherwise each key maps to the attribute
        of that name in the imported module, or `None` when the module has no such attribute.
    """
    if is_pipeline_module:
        source_module = getattr(pipelines, library_name)
    else:
        # not a pipeline module: import the library by name
        source_module = importlib.import_module(library_name)

    class_obj = getattr(source_module, class_name)

    if is_pipeline_module:
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    else:
        class_candidates = {key: getattr(source_module, key, None) for key in importable_classes.keys()}

    return class_obj, class_candidates
|
||||||
|
|
||||||
|
|
||||||
|
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """Load the pipeline component `name`, whose class is `class_name` from `library_name`.

    The class is resolved with `get_class_obj_and_candidates`. The loading method is taken from
    index 1 of the `importable_classes` entry belonging to the last candidate class that the
    component class is a subclass of. Keyword arguments for that method are assembled according
    to the kind of class being loaded. The component is then loaded from
    `<cached_folder>/<name>` when that directory exists, and from `cached_folder` otherwise.

    Note that the entry for `name` is popped from `model_variants` when the component is a
    diffusers or transformers model.

    Raises:
        ValueError: If no loading method could be determined for the component class.
        ImportError: If a variant is requested for a transformers model while the installed
            `transformers` base version is below 4.27.0.

    Returns:
        Whatever the selected loading method returns.
    """
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name, class_name, importable_classes, pipelines, is_pipeline_module
    )

    # retrieve the load method name; there is no `break`, so the last matching candidate wins
    load_method_name = None
    for candidate_name, candidate in class_candidates.items():
        if candidate is not None and issubclass(class_obj, candidate):
            load_method_name = importable_classes[candidate_name][1]

    # without a load method name the component is a dummy module or unsupported -> raise
    if load_method_name is None:
        module_path = class_obj.__module__
        in_dummy_folder = module_path.startswith(DUMMY_MODULES_FOLDER) or module_path.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if in_dummy_folder and "dummy" in module_path:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # collect the keyword arguments for the loading method
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers.OnnxRuntimeModel):
        loading_kwargs.update(provider=provider, sess_options=sess_options)

    is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)

    # The "N/A" placeholder is never compared against a version: every comparison below is
    # short-circuited by `is_transformers_available()` / `is_transformers_model`.
    transformers_version = (
        version.parse(version.parse(transformers.__version__).base_version)
        if is_transformers_available()
        else "N/A"
    )

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        component_variant = model_variants.pop(name, None)

        loading_kwargs["device_map"] = device_map
        loading_kwargs["variant"] = component_variant
        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if is_transformers_model:
            if component_variant is None:
                del loading_kwargs["variant"]
            elif transformers_version < version.parse("4.27.0"):
                raise ImportError(
                    f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
                )

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        loading_kwargs["low_cpu_mem_usage"] = False if (from_flax and is_transformers_model) else low_cpu_mem_usage

    # prefer the component's own subdirectory; fall back to the root directory
    component_dir = os.path.join(cached_folder, name)
    load_location = component_dir if os.path.isdir(component_dir) else cached_folder

    return load_method(load_location, **loading_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPipeline(ConfigMixin):
|
class DiffusionPipeline(ConfigMixin):
|
||||||
r"""
|
r"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
|
@ -524,8 +695,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||||
setting this argument to `True` will raise an error.
|
setting this argument to `True` will raise an error.
|
||||||
return_cached_folder (`bool`, *optional*, defaults to `False`):
|
|
||||||
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
|
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||||
|
@ -583,13 +752,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
sess_options = kwargs.pop("sess_options", None)
|
sess_options = kwargs.pop("sess_options", None)
|
||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
|
||||||
variant = kwargs.pop("variant", None)
|
variant = kwargs.pop("variant", None)
|
||||||
|
|
||||||
# 1. Download the checkpoints and configs
|
# 1. Download the checkpoints and configs
|
||||||
# use snapshot download here to get it working from from_pretrained
|
# use snapshot download here to get it working from from_pretrained
|
||||||
if not os.path.isdir(pretrained_model_name_or_path):
|
if not os.path.isdir(pretrained_model_name_or_path):
|
||||||
config_dict = cls.load_config(
|
cached_folder = cls.download(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
@ -598,117 +766,18 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
)
|
from_flax=from_flax,
|
||||||
|
custom_pipeline=custom_pipeline,
|
||||||
# retrieve all folder_names that contain relevant files
|
variant=variant,
|
||||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
|
||||||
|
|
||||||
if not local_files_only:
|
|
||||||
info = model_info(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
|
|
||||||
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
|
|
||||||
|
|
||||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
|
||||||
version.parse(__version__).base_version
|
|
||||||
) >= version.parse("0.17.0"):
|
|
||||||
info = model_info(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=None,
|
|
||||||
)
|
|
||||||
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
|
|
||||||
comp_model_filenames = [
|
|
||||||
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
|
|
||||||
]
|
|
||||||
|
|
||||||
if set(comp_model_filenames) == set(model_filenames):
|
|
||||||
warnings.warn(
|
|
||||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
# all filenames compatible with variant will be added
|
|
||||||
allow_patterns = list(model_filenames)
|
|
||||||
|
|
||||||
# allow all patterns from non-model folders
|
|
||||||
# this enables downloading schedulers, tokenizers, ...
|
|
||||||
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
|
|
||||||
# also allow downloading config.jsons with the model
|
|
||||||
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
|
|
||||||
|
|
||||||
allow_patterns += [
|
|
||||||
SCHEDULER_CONFIG_NAME,
|
|
||||||
CONFIG_NAME,
|
|
||||||
cls.config_name,
|
|
||||||
CUSTOM_PIPELINE_FILE_NAME,
|
|
||||||
]
|
|
||||||
|
|
||||||
if from_flax:
|
|
||||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
|
||||||
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
|
|
||||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
|
||||||
|
|
||||||
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
|
|
||||||
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
|
|
||||||
if (
|
|
||||||
len(safetensors_variant_filenames) > 0
|
|
||||||
and safetensors_model_filenames != safetensors_variant_filenames
|
|
||||||
):
|
|
||||||
logger.warn(
|
|
||||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
|
||||||
|
|
||||||
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
|
|
||||||
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
|
|
||||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
|
||||||
logger.warn(
|
|
||||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# allow everything since it has to be downloaded anyways
|
|
||||||
ignore_patterns = allow_patterns = None
|
|
||||||
|
|
||||||
if cls != DiffusionPipeline:
|
|
||||||
requested_pipeline_class = cls.__name__
|
|
||||||
else:
|
|
||||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
|
||||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
|
||||||
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
|
|
||||||
user_agent["custom_pipeline"] = custom_pipeline
|
|
||||||
|
|
||||||
user_agent = http_user_agent(user_agent)
|
|
||||||
|
|
||||||
# download all allow_patterns
|
|
||||||
cached_folder = snapshot_download(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
ignore_patterns=ignore_patterns,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cached_folder = pretrained_model_name_or_path
|
cached_folder = pretrained_model_name_or_path
|
||||||
config_dict = cls.load_config(cached_folder)
|
|
||||||
|
|
||||||
# retrieve which subfolders should load variants
|
config_dict = cls.load_config(cached_folder)
|
||||||
|
|
||||||
|
# 2. Define which model components should load variants
|
||||||
|
# We retrieve the information by matching whether variant
|
||||||
|
# model checkpoints exist in the subfolders
|
||||||
model_variants = {}
|
model_variants = {}
|
||||||
if variant is not None:
|
if variant is not None:
|
||||||
for folder in os.listdir(cached_folder):
|
for folder in os.listdir(cached_folder):
|
||||||
|
@ -718,7 +787,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
if variant_exists:
|
if variant_exists:
|
||||||
model_variants[folder] = variant
|
model_variants[folder] = variant
|
||||||
|
|
||||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||||
# if we load from explicit class, let's use it
|
# if we load from explicit class, let's use it
|
||||||
if custom_pipeline is not None:
|
if custom_pipeline is not None:
|
||||||
if custom_pipeline.endswith(".py"):
|
if custom_pipeline.endswith(".py"):
|
||||||
|
@ -738,7 +807,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
||||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||||
|
|
||||||
# To be removed in 1.0.0
|
# DEPRECATED: To be removed in 1.0.0
|
||||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
||||||
version.parse(config_dict["_diffusers_version"]).base_version
|
version.parse(config_dict["_diffusers_version"]).base_version
|
||||||
) <= version.parse("0.5.1"):
|
) <= version.parse("0.5.1"):
|
||||||
|
@ -757,6 +826,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
)
|
)
|
||||||
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
|
|
||||||
|
# 4. Define expected modules given pipeline signature
|
||||||
|
# and define non-None initialized modules (=`init_kwargs`)
|
||||||
|
|
||||||
# some modules can be passed directly to the init
|
# some modules can be passed directly to the init
|
||||||
# in this case they are already instantiated in `kwargs`
|
# in this case they are already instantiated in `kwargs`
|
||||||
# extract them here
|
# extract them here
|
||||||
|
@ -788,6 +860,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
" separately if you need it."
|
" separately if you need it."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 5. Throw nice warnings / errors for fast accelerate loading
|
||||||
if len(unused_kwargs) > 0:
|
if len(unused_kwargs) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||||
|
@ -823,135 +896,50 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
# import it here to avoid circular import
|
# import it here to avoid circular import
|
||||||
from diffusers import pipelines
|
from diffusers import pipelines
|
||||||
|
|
||||||
# 3. Load each module in the pipeline
|
# 6. Load each module in the pipeline
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||||
if class_name.startswith("Flax"):
|
if class_name.startswith("Flax"):
|
||||||
class_name = class_name[4:]
|
class_name = class_name[4:]
|
||||||
|
|
||||||
|
# 6.2 Define all importable classes
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
|
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
|
||||||
loaded_sub_model = None
|
loaded_sub_model = None
|
||||||
|
|
||||||
# if the model is in a pipeline module, then we load it from the pipeline
|
# 6.3 Use passed sub model or load class_name from library_name
|
||||||
if name in passed_class_obj:
|
if name in passed_class_obj:
|
||||||
# 1. check that passed_class_obj has correct parent class
|
# if the model is in a pipeline module, then we load it from the pipeline
|
||||||
if not is_pipeline_module:
|
# check that passed_class_obj has correct parent class
|
||||||
library = importlib.import_module(library_name)
|
maybe_raise_or_warn(
|
||||||
class_obj = getattr(library, class_name)
|
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
|
||||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
|
||||||
|
|
||||||
expected_class_obj = None
|
|
||||||
for class_name, class_candidate in class_candidates.items():
|
|
||||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
|
||||||
expected_class_obj = class_candidate
|
|
||||||
|
|
||||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
|
||||||
raise ValueError(
|
|
||||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
|
||||||
f" {expected_class_obj}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
|
||||||
" has the correct type"
|
|
||||||
)
|
|
||||||
|
|
||||||
# set passed class object
|
|
||||||
loaded_sub_model = passed_class_obj[name]
|
|
||||||
elif is_pipeline_module:
|
|
||||||
pipeline_module = getattr(pipelines, library_name)
|
|
||||||
class_obj = getattr(pipeline_module, class_name)
|
|
||||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
|
||||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
|
||||||
else:
|
|
||||||
# else we just import it from the library.
|
|
||||||
library = importlib.import_module(library_name)
|
|
||||||
|
|
||||||
class_obj = getattr(library, class_name)
|
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
|
||||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
|
||||||
|
|
||||||
if loaded_sub_model is None:
|
|
||||||
load_method_name = None
|
|
||||||
for class_name, class_candidate in class_candidates.items():
|
|
||||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
|
||||||
load_method_name = importable_classes[class_name][1]
|
|
||||||
|
|
||||||
if load_method_name is None:
|
|
||||||
none_module = class_obj.__module__
|
|
||||||
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
|
||||||
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
|
||||||
)
|
|
||||||
if is_dummy_path and "dummy" in none_module:
|
|
||||||
# call class_obj for nice error message of missing requirements
|
|
||||||
class_obj()
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
|
||||||
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
|
||||||
)
|
|
||||||
|
|
||||||
load_method = getattr(class_obj, load_method_name)
|
|
||||||
loading_kwargs = {}
|
|
||||||
|
|
||||||
if issubclass(class_obj, torch.nn.Module):
|
|
||||||
loading_kwargs["torch_dtype"] = torch_dtype
|
|
||||||
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
|
|
||||||
loading_kwargs["provider"] = provider
|
|
||||||
loading_kwargs["sess_options"] = sess_options
|
|
||||||
|
|
||||||
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
|
||||||
|
|
||||||
if is_transformers_available():
|
|
||||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
|
||||||
else:
|
|
||||||
transformers_version = "N/A"
|
|
||||||
|
|
||||||
is_transformers_model = (
|
|
||||||
is_transformers_available()
|
|
||||||
and issubclass(class_obj, PreTrainedModel)
|
|
||||||
and transformers_version >= version.parse("4.20.0")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
loaded_sub_model = passed_class_obj[name]
|
||||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
else:
|
||||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
# load sub model
|
||||||
if is_diffusers_model or is_transformers_model:
|
loaded_sub_model = load_sub_model(
|
||||||
loading_kwargs["device_map"] = device_map
|
library_name=library_name,
|
||||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
class_name=class_name,
|
||||||
if from_flax:
|
importable_classes=importable_classes,
|
||||||
loading_kwargs["from_flax"] = True
|
pipelines=pipelines,
|
||||||
|
is_pipeline_module=is_pipeline_module,
|
||||||
# the following can be deleted once the minimum required `transformers` version
|
pipeline_class=pipeline_class,
|
||||||
# is higher than 4.27
|
torch_dtype=torch_dtype,
|
||||||
if (
|
provider=provider,
|
||||||
is_transformers_model
|
sess_options=sess_options,
|
||||||
and loading_kwargs["variant"] is not None
|
device_map=device_map,
|
||||||
and transformers_version < version.parse("4.27.0")
|
model_variants=model_variants,
|
||||||
):
|
name=name,
|
||||||
raise ImportError(
|
from_flax=from_flax,
|
||||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
variant=variant,
|
||||||
)
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
cached_folder=cached_folder,
|
||||||
loading_kwargs.pop("variant")
|
)
|
||||||
|
|
||||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
|
||||||
if not (from_flax and is_transformers_model):
|
|
||||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
|
||||||
else:
|
|
||||||
loading_kwargs["low_cpu_mem_usage"] = False
|
|
||||||
|
|
||||||
# check if the module is in a subdirectory
|
|
||||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
|
||||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
|
||||||
else:
|
|
||||||
# else load from the root directory
|
|
||||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
|
||||||
|
|
||||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||||
|
|
||||||
# 4. Potentially add passed objects if expected
|
# 7. Potentially add passed objects if expected
|
||||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||||
passed_modules = list(passed_class_obj.keys())
|
passed_modules = list(passed_class_obj.keys())
|
||||||
optional_modules = pipeline_class._optional_components
|
optional_modules = pipeline_class._optional_components
|
||||||
|
@ -964,13 +952,251 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Instantiate the pipeline
|
# 8. Instantiate the pipeline
|
||||||
model = pipeline_class(**init_kwargs)
|
model = pipeline_class(**init_kwargs)
|
||||||
|
|
||||||
|
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||||
if return_cached_folder:
|
if return_cached_folder:
|
||||||
|
message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
|
||||||
|
deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
|
||||||
return model, cached_folder
|
return model, cached_folder
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@classmethod
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
    r"""
    Download and cache a PyTorch diffusion pipeline from pre-trained pipeline weights.

    Parameters:
        pretrained_model_name (`str` or `os.PathLike`, *optional*):
            Should be a string, the *repo id* of a pretrained pipeline hosted inside a model repo on
            https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
            `CompVis/ldm-text2im-large-256`.
        custom_pipeline (`str`, *optional*):

            <Tip warning={true}>

                This is an experimental feature and is likely to change in the future.

            </Tip>

            Can be either:

                - A string, the *repo id* of a custom pipeline hosted inside a model repo on
                  https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
                  like `hf-internal-testing/diffusers-dummy-pipeline`.

                    <Tip>

                     It is required that the model repo has a file, called `pipeline.py` that defines the custom
                     pipeline.

                    </Tip>

                - A string, the *file name* of a community pipeline hosted on GitHub under
                  https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
                  match exactly the file name without `.py` located under the above link, *e.g.*
                  `clip_guided_stable_diffusion`.

                    <Tip>

                     Community pipelines are always loaded from the current `main` branch of GitHub.

                    </Tip>

                - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.

                    <Tip>

                     It is required that the directory has a file, called `pipeline.py` that defines the custom
                     pipeline.

                    </Tip>

            For more information on how to load and create custom pipelines, please have a look at [Loading and
            Adding Custom
            Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)

            In this method the value is only recorded in the user agent (when it does not end in `.py`).
        cache_dir (`str` or `os.PathLike`, *optional*, defaults to `DIFFUSERS_CACHE`):
            Directory in which the downloaded files are cached.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the pipeline configuration file, overriding the cached
            version if it exists.
        resume_download (`bool`, *optional*, defaults to `False`):
            Forwarded to the Hub download functions. Will attempt to resume the download if an incompletely
            received file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only(`bool`, *optional*, defaults to `HF_HUB_OFFLINE`):
            Whether or not to only look at local files (i.e., do not try to download the model). When `True`, the
            Hub is not queried for model info and no file filtering patterns are computed.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        from_flax (`bool`, *optional*, defaults to `False`):
            If `True`, PyTorch, safetensors, ONNX and protobuf weight files are excluded from the download.
        variant (`str`, *optional*):
            If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
            ignored when using `from_flax`.

    Any other keyword argument is accepted but not read by this method.

    Returns:
        `str` or `os.PathLike`: The snapshot folder of the cached pipeline. When every expected file is already
        present next to the cached configuration file, that folder is returned directly; otherwise the return
        value of `snapshot_download` is returned.

    <Tip>

     It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
     models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`

    </Tip>

    <Tip>

    Activate the special
    ["offline-mode"](https://huggingface.co/diffusers/installation.html#notice-on-telemetry-logging) to use this
    method in a firewalled environment.

    </Tip>
    """
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    resume_download = kwargs.pop("resume_download", False)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_flax = kwargs.pop("from_flax", False)
    custom_pipeline = kwargs.pop("custom_pipeline", None)
    variant = kwargs.pop("variant", None)

    pipeline_is_cached = False
    allow_patterns = None
    ignore_patterns = None

    user_agent = {"pipeline_class": cls.__name__}
    if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
        user_agent["custom_pipeline"] = custom_pipeline

    if not local_files_only:
        # A single `model_info` call yields both the resolved commit sha and the repo file listing
        info = model_info(
            pretrained_model_name,
            use_auth_token=use_auth_token,
            revision=revision,
        )
        user_agent["pretrained_model_name"] = pretrained_model_name
        send_telemetry("pipelines", library_name="diffusers", library_version=__version__, user_agent=user_agent)
        commit_hash = info.sha

        # try loading the config file (pinned to the resolved commit)
        config_file = hf_hub_download(
            pretrained_model_name,
            cls.config_name,
            cache_dir=cache_dir,
            revision=commit_hash,
            proxies=proxies,
            force_download=force_download,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
        )

        config_dict = cls._dict_from_json_file(config_file)
        config_is_cached = True

        # retrieve all folder_names that contain relevant files
        folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]

        filenames = {sibling.rfilename for sibling in info.siblings}
        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

        # warn about the deprecated way of selecting a variant through `revision`
        # (only active once the installed version reaches 0.17.0)
        if revision in DEPRECATED_REVISION_ARGS and version.parse(
            version.parse(__version__).base_version
        ) >= version.parse("0.17.0"):
            warn_deprecated_model_variant(
                pretrained_model_name, use_auth_token, variant, revision, model_filenames
            )

        model_folder_names = {os.path.split(f)[0] for f in model_filenames}

        # all filenames compatible with variant will be added
        allow_patterns = list(model_filenames)

        # allow all patterns from non-model folders
        # this enables downloading schedulers, tokenizers, ...
        allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
        # also allow downloading config.json files with the model
        allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]

        allow_patterns += [
            SCHEDULER_CONFIG_NAME,
            CONFIG_NAME,
            cls.config_name,
            CUSTOM_PIPELINE_FILE_NAME,
        ]

        if from_flax:
            ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
        else:
            if is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
                ignore_patterns = ["*.bin", "*.msgpack"]
                weights_suffix = ".safetensors"
            else:
                ignore_patterns = ["*.safetensors", "*.msgpack"]
                weights_suffix = ".bin"

            # warn when only part of the weights of the selected format exist in the requested variant
            loaded_variant_filenames = {f for f in variant_filenames if f.endswith(weights_suffix)}
            loaded_model_filenames = {f for f in model_filenames if f.endswith(weights_suffix)}
            if len(loaded_variant_filenames) > 0 and loaded_model_filenames != loaded_variant_filenames:
                logger.warning(
                    f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
                    f"[{', '.join(sorted(loaded_variant_filenames))}]\nLoaded non-{variant} filenames:\n"
                    f"[{', '.join(sorted(loaded_model_filenames - loaded_variant_filenames))}]\n"
                    "If this behavior is not expected, please check your folder structure."
                )

        if config_is_cached:
            # mirror the allow/ignore filtering locally to know which files a download would fetch
            re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
            re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

            expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
            expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]

            snapshot_folder = Path(config_file).parent
            pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)

            if pipeline_is_cached:
                # if the pipeline is cached, we can directly return it
                # else call snapshot_download
                return snapshot_folder

    # download all allow_patterns - ignore_patterns
    cached_folder = snapshot_download(
        pretrained_model_name,
        cache_dir=cache_dir,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
        user_agent=user_agent,
    )

    return cached_folder
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_signature_keys(obj):
|
def _get_signature_keys(obj):
|
||||||
parameters = inspect.signature(obj.__init__).parameters
|
parameters = inspect.signature(obj.__init__).parameters
|
||||||
|
|
|
@ -954,7 +954,7 @@ def stable_unclip_image_noising_components(
|
||||||
return image_normalizer, image_noising_scheduler
|
return image_normalizer, image_noising_scheduler
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
def download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
original_config_file: str = None,
|
original_config_file: str = None,
|
||||||
image_size: int = 512,
|
image_size: int = 512,
|
||||||
|
|
|
@ -136,10 +136,11 @@ class SchedulerMixin:
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config, kwargs = cls.load_config(
|
config, kwargs, commit_hash = cls.load_config(
|
||||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
return_unused_kwargs=True,
|
return_unused_kwargs=True,
|
||||||
|
return_commit_hash=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
||||||
|
|
|
@ -35,7 +35,11 @@ from .constants import (
|
||||||
from .deprecation_utils import deprecate
|
from .deprecation_utils import deprecate
|
||||||
from .doc_utils import replace_example_docstring
|
from .doc_utils import replace_example_docstring
|
||||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||||
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
|
from .hub_utils import (
|
||||||
|
HF_HUB_OFFLINE,
|
||||||
|
extract_commit_hash,
|
||||||
|
http_user_agent,
|
||||||
|
)
|
||||||
from .import_utils import (
|
from .import_utils import (
|
||||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||||
ENV_VARS_TRUE_VALUES,
|
ENV_VARS_TRUE_VALUES,
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -22,6 +23,7 @@ from typing import Dict, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
|
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
|
||||||
|
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||||
from huggingface_hub.utils import is_jinja_available
|
from huggingface_hub.utils import is_jinja_available
|
||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
|
@ -132,6 +134,20 @@ def create_model_card(args, model_name):
|
||||||
model_card.save(card_path)
|
model_card.save(card_path)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
    """
    Derive the commit hash from the path of a resolved cache file.

    An explicitly passed `commit_hash` always wins and is returned unchanged; it is also what is returned when
    `resolved_file` is `None`. Otherwise the path component following `snapshots/` is taken as the candidate and
    returned only if it matches `REGEX_COMMIT_HASH`; in every other case `None` is returned.
    """
    if commit_hash is not None or resolved_file is None:
        return commit_hash

    # normalise separators so the pattern below only has to deal with "/"
    posix_path = Path(resolved_file).as_posix()
    snapshot_match = re.search(r"snapshots/([^/]+)/", posix_path)
    if snapshot_match is None:
        return None

    candidate = snapshot_match.group(1)
    if REGEX_COMMIT_HASH.match(candidate):
        return candidate
    return None
|
||||||
|
|
||||||
|
|
||||||
# Old default cache path, potentially to be migrated.
|
# Old default cache path, potentially to be migrated.
|
||||||
# This logic was more or less taken from `transformers`, with the following differences:
|
# This logic was more or less taken from `transformers`, with the following differences:
|
||||||
# - Diffusers doesn't use custom environment variables to specify the cache path.
|
# - Diffusers doesn't use custom environment variables to specify the cache path.
|
||||||
|
@ -150,7 +166,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
|
||||||
|
|
||||||
old_cache_dir = Path(old_cache_dir).expanduser()
|
old_cache_dir = Path(old_cache_dir).expanduser()
|
||||||
new_cache_dir = Path(new_cache_dir).expanduser()
|
new_cache_dir = Path(new_cache_dir).expanduser()
|
||||||
for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob
|
for old_blob_path in old_cache_dir.glob("**/blobs/*"):
|
||||||
if old_blob_path.is_file() and not old_blob_path.is_symlink():
|
if old_blob_path.is_file() and not old_blob_path.is_symlink():
|
||||||
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
|
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
|
||||||
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
|
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
|
@ -20,6 +20,7 @@ import unittest.mock as mock
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests_mock
|
||||||
import torch
|
import torch
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
@ -29,6 +30,13 @@ from diffusers.utils import torch_device
|
||||||
|
|
||||||
|
|
||||||
class ModelUtilsTest(unittest.TestCase):
|
class ModelUtilsTest(unittest.TestCase):
|
||||||
|
def tearDown(self):
    """Reset the safetensors availability flag to `True` after every test."""
    super().tearDown()

    import diffusers

    # `test_one_request_upon_cached` switches this flag off; make sure a failing
    # test cannot leave it disabled for the tests that run afterwards
    diffusers.utils.import_utils._safetensors_available = True
|
||||||
|
|
||||||
def test_accelerate_loading_error_message(self):
|
def test_accelerate_loading_error_message(self):
|
||||||
with self.assertRaises(ValueError) as error_context:
|
with self.assertRaises(ValueError) as error_context:
|
||||||
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
|
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
|
||||||
|
@ -60,6 +68,37 @@ class ModelUtilsTest(unittest.TestCase):
|
||||||
if p1.data.ne(p2.data).sum() > 0:
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
assert False, "Parameters not the same!"
|
assert False, "Parameters not the same!"
|
||||||
|
|
||||||
|
def test_one_request_upon_cached(self):
    """
    Check the number of HTTP requests made when loading a model twice into the same cache directory.

    The first load must issue two HEAD and two GET requests (config and weights); the second load, served from
    the cache, must issue exactly one HEAD request.
    """
    # TODO: For some reason this test fails on MPS where no HEAD call is made.
    if torch_device == "mps":
        # report an explicit skip instead of silently counting the test as passed
        self.skipTest("Fails on MPS where no HEAD call is made.")

    import diffusers

    # force the `.bin` code path so the request counts below are deterministic
    diffusers.utils.import_utils._safetensors_available = False

    with tempfile.TemporaryDirectory() as tmpdirname:
        with requests_mock.mock(real_http=True) as m:
            UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
            )

        download_requests = [r.method for r in m.request_history]
        assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
        assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

        with requests_mock.mock(real_http=True) as m:
            UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
            )

        cache_requests = [r.method for r in m.request_history]
        assert (
            "HEAD" == cache_requests[0] and len(cache_requests) == 1
        ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

    diffusers.utils.import_utils._safetensors_available = True
|
||||||
|
|
||||||
|
|
||||||
class ModelTesterMixin:
|
class ModelTesterMixin:
|
||||||
def test_from_save_pretrained(self):
|
def test_from_save_pretrained(self):
|
||||||
|
|
|
@ -25,6 +25,7 @@ import unittest.mock as mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
|
import requests_mock
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
@ -61,14 +62,44 @@ torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
|
||||||
|
|
||||||
class DownloadTests(unittest.TestCase):
|
class DownloadTests(unittest.TestCase):
|
||||||
|
def test_one_request_upon_cached(self):
    """Check that downloading a cached pipeline issues only two HTTP requests.

    `DiffusionPipeline.download` is called twice with the same temporary
    cache directory while `requests_mock` records (but does not intercept)
    the real HTTP traffic:

    * the first call is expected to issue 16 HEAD and 17 GET requests
      (33 in total),
    * the second call is expected to issue one HEAD and one GET request.
    """
    # TODO: For some reason this test fails on MPS where no HEAD call is made.
    if torch_device == "mps":
        # Report the test as skipped rather than letting it silently pass.
        self.skipTest("No HEAD call is made on MPS.")

    with tempfile.TemporaryDirectory() as tmpdirname:

        def recorded_methods():
            """Run one download into `tmpdirname` and return the HTTP methods used."""
            with requests_mock.mock(real_http=True) as m:
                DiffusionPipeline.download(
                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
                )
            return [r.method for r in m.request_history]

        # First call: nothing is cached yet, every file is fetched.
        download_requests = recorded_methods()
        assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry"
        assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
        assert (
            len(download_requests) == 33
        ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"

        # Second call: the files are already in `tmpdirname`.
        cache_requests = recorded_methods()
        assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD"
        assert cache_requests.count("GET") == 1, "model info is only GET"
        assert (
            len(cache_requests) == 2
        ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
|
||||||
|
|
||||||
def test_download_only_pytorch(self):
|
def test_download_only_pytorch(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
# pipeline has Flax weights
|
# pipeline has Flax weights
|
||||||
_ = DiffusionPipeline.from_pretrained(
|
tmpdirname = DiffusionPipeline.download(
|
||||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
|
||||||
)
|
)
|
||||||
|
|
||||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
|
|
||||||
# None of the downloaded files should be a flax file even if we have some here:
|
# None of the downloaded files should be a flax file even if we have some here:
|
||||||
|
@ -101,13 +132,13 @@ class DownloadTests(unittest.TestCase):
|
||||||
def test_download_safetensors(self):
|
def test_download_safetensors(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
# pipeline has Flax weights
|
# pipeline has Flax weights
|
||||||
_ = DiffusionPipeline.from_pretrained(
|
tmpdirname = DiffusionPipeline.download(
|
||||||
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
|
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
cache_dir=tmpdirname,
|
cache_dir=tmpdirname,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
|
|
||||||
# None of the downloaded files should be a pytorch file even if we have some here:
|
# None of the downloaded files should be a pytorch file even if we have some here:
|
||||||
|
@ -204,12 +235,10 @@ class DownloadTests(unittest.TestCase):
|
||||||
|
|
||||||
other_format = ".bin" if safe_avail else ".safetensors"
|
other_format = ".bin" if safe_avail else ".safetensors"
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
StableDiffusionPipeline.from_pretrained(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
|
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
|
||||||
)
|
)
|
||||||
all_root_files = [
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
|
|
||||||
]
|
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
|
|
||||||
# None of the downloaded files should be a variant file even if we have some here:
|
# None of the downloaded files should be a variant file even if we have some here:
|
||||||
|
@ -232,12 +261,10 @@ class DownloadTests(unittest.TestCase):
|
||||||
variant = "fp16"
|
variant = "fp16"
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
StableDiffusionPipeline.from_pretrained(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||||
)
|
)
|
||||||
all_root_files = [
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
|
|
||||||
]
|
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
|
|
||||||
# None of the downloaded files should be a non-variant file even if we have some here:
|
# None of the downloaded files should be a non-variant file even if we have some here:
|
||||||
|
@ -262,14 +289,13 @@ class DownloadTests(unittest.TestCase):
|
||||||
variant = "no_ema"
|
variant = "no_ema"
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
StableDiffusionPipeline.from_pretrained(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||||
)
|
)
|
||||||
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
all_root_files = [t[-1] for t in os.walk(snapshots)]
|
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
|
|
||||||
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))
|
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
|
||||||
|
|
||||||
# Some of the downloaded files should be a non-variant file, check:
|
# Some of the downloaded files should be a non-variant file, check:
|
||||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||||
|
@ -292,7 +318,7 @@ class DownloadTests(unittest.TestCase):
|
||||||
for variant in [None, "no_ema"]:
|
for variant in [None, "no_ema"]:
|
||||||
with self.assertRaises(OSError) as error_context:
|
with self.assertRaises(OSError) as error_context:
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
StableDiffusionPipeline.from_pretrained(
|
tmpdirname = StableDiffusionPipeline.from_pretrained(
|
||||||
"hf-internal-testing/stable-diffusion-broken-variants",
|
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||||
cache_dir=tmpdirname,
|
cache_dir=tmpdirname,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
|
@ -302,13 +328,11 @@ class DownloadTests(unittest.TestCase):
|
||||||
|
|
||||||
# text encoder has fp16 variants so we can load it
|
# text encoder has fp16 variants so we can load it
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
tmpdirname = StableDiffusionPipeline.download(
|
||||||
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
|
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
|
||||||
)
|
)
|
||||||
assert pipe is not None
|
|
||||||
|
|
||||||
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
|
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
|
||||||
all_root_files = [t[-1] for t in os.walk(snapshots)]
|
|
||||||
files = [item for sublist in all_root_files for item in sublist]
|
files = [item for sublist in all_root_files for item in sublist]
|
||||||
|
|
||||||
# None of the downloaded files should be a non-variant file even if we have some here:
|
# None of the downloaded files should be a non-variant file even if we have some here:
|
||||||
|
@ -395,7 +419,7 @@ class CustomPipelineTests(unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_load_pipeline_from_git(self):
|
def test_download_from_git(self):
|
||||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||||
|
|
||||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
||||||
|
|
Loading…
Reference in New Issue