Merge branch 'main' of github.com:huggingface/diffusers
This commit is contained in:
commit
ae73d95e41
51
setup.py
51
setup.py
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -52,11 +52,11 @@ To create the package for pypi.
|
||||||
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
||||||
|
|
||||||
Check that you can install it in a virtualenv by running:
|
Check that you can install it in a virtualenv by running:
|
||||||
pip install -i https://testpypi.python.org/pypi transformers
|
pip install -i https://testpypi.python.org/pypi diffusers
|
||||||
|
|
||||||
Check you can run the following commands:
|
Check you can run the following commands:
|
||||||
python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
|
python -c "from diffusers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
|
||||||
python -c "from transformers import *"
|
python -c "from diffusers import *"
|
||||||
|
|
||||||
9. Upload the final version to actual pypi:
|
9. Upload the final version to actual pypi:
|
||||||
twine upload dist/* -r pypi
|
twine upload dist/* -r pypi
|
||||||
|
@ -77,36 +77,21 @@ from setuptools import find_packages, setup
|
||||||
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
|
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
|
||||||
_deps = [
|
_deps = [
|
||||||
"Pillow",
|
"Pillow",
|
||||||
"accelerate>=0.9.0",
|
|
||||||
"black~=22.0,>=22.3",
|
"black~=22.0,>=22.3",
|
||||||
"codecarbon==1.2.0",
|
"filelock",
|
||||||
"dataclasses",
|
"flake8>=3.8.3",
|
||||||
"datasets",
|
"huggingface-hub",
|
||||||
"GitPython<3.1.19",
|
|
||||||
"hf-doc-builder>=0.3.0",
|
|
||||||
"huggingface-hub>=0.1.0,<1.0",
|
|
||||||
"importlib_metadata",
|
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
"numpy>=1.17",
|
"numpy",
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-timeout",
|
|
||||||
"pytest-xdist",
|
|
||||||
"python>=3.7.0",
|
|
||||||
"regex!=2019.12.17",
|
|
||||||
"requests",
|
"requests",
|
||||||
"sagemaker>=2.31.0",
|
|
||||||
"tokenizers>=0.11.1,!=0.11.3,<0.13",
|
|
||||||
"torch>=1.4",
|
"torch>=1.4",
|
||||||
"torchaudio",
|
"torchvision",
|
||||||
"tqdm>=4.27",
|
|
||||||
"unidic>=1.0.2",
|
|
||||||
"unidic_lite>=1.0.7",
|
|
||||||
"uvicorn",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# this is a lookup table with items like:
|
# this is a lookup table with items like:
|
||||||
#
|
#
|
||||||
# tokenizers: "tokenizers==0.9.4"
|
# tokenizers: "huggingface-hub==0.8.0"
|
||||||
# packaging: "packaging"
|
# packaging: "packaging"
|
||||||
#
|
#
|
||||||
# some of the values are versioned whereas others aren't.
|
# some of the values are versioned whereas others aren't.
|
||||||
|
@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
|
||||||
extras["docs"] = []
|
extras["docs"] = []
|
||||||
extras["test"] = [
|
extras["test"] = [
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-xdist",
|
|
||||||
"pytest-subtests",
|
|
||||||
"datasets",
|
|
||||||
"transformers",
|
|
||||||
]
|
]
|
||||||
extras["dev"] = extras["quality"] + extras["test"]
|
extras["dev"] = extras["quality"] + extras["test"]
|
||||||
|
|
||||||
extras["sagemaker"] = [
|
install_requires = [
|
||||||
"sagemaker", # boto3 is a required package in sagemaker
|
deps["filelock"],
|
||||||
|
deps["huggingface-hub"],
|
||||||
|
deps["numpy"],
|
||||||
|
deps["requests"],
|
||||||
|
deps["torch"],
|
||||||
|
deps["torchvision"],
|
||||||
|
deps["Pillow"],
|
||||||
]
|
]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
@ -201,7 +188,7 @@ setup(
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
packages=find_packages("src"),
|
packages=find_packages("src"),
|
||||||
python_requires=">=3.6.0",
|
python_requires=">=3.6.0",
|
||||||
install_requires=["numpy>=1.17", "packaging>=20.0", "pyyaml", "torch>=1.4.0"],
|
install_requires=install_requires,
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 5 - Production/Stable",
|
"Development Status :: 5 - Production/Stable",
|
||||||
|
|
|
@ -24,18 +24,19 @@ import re
|
||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from transformers.utils import (
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
|
DIFFUSERS_CACHE,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
cached_path,
|
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
|
||||||
is_remote_url,
|
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,6 +57,8 @@ class ConfigMixin:
|
||||||
if self.config_name is None:
|
if self.config_name is None:
|
||||||
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
||||||
kwargs["_class_name"] = self.__class__.__name__
|
kwargs["_class_name"] = self.__class__.__name__
|
||||||
|
kwargs["_diffusers_version"] = __version__
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
try:
|
try:
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
@ -90,11 +93,26 @@ class ConfigMixin:
|
||||||
self.to_json_file(output_config_file)
|
self.to_json_file(output_config_file)
|
||||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||||
|
config_dict = cls.get_config_dict(
|
||||||
|
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
model = cls(**init_dict)
|
||||||
|
|
||||||
|
if return_unused_kwargs:
|
||||||
|
return model, unused_kwargs
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_dict(
|
def get_config_dict(
|
||||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
@ -104,85 +122,83 @@ class ConfigMixin:
|
||||||
|
|
||||||
user_agent = {"file_type": "config"}
|
user_agent = {"file_type": "config"}
|
||||||
|
|
||||||
if is_offline_mode() and not local_files_only:
|
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
|
||||||
local_files_only = True
|
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
|
||||||
config_file = pretrained_model_name_or_path
|
|
||||||
else:
|
|
||||||
configuration_file = cls.config_name
|
|
||||||
|
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if cls.config_name is None:
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
raise ValueError(
|
||||||
|
"`self.config_name` is not defined. Note that one should not load a config from "
|
||||||
|
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
|
config_file = pretrained_model_name_or_path
|
||||||
|
elif os.path.isdir(pretrained_model_name_or_path):
|
||||||
|
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
||||||
|
# Load from a PyTorch checkpoint
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
||||||
else:
|
else:
|
||||||
config_file = hf_bucket_url(
|
raise EnvironmentError(
|
||||||
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
|
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# Load from URL or cache if already cached
|
||||||
|
config_file = hf_hub_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
filename=cls.config_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||||
|
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||||
|
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||||
|
"`use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||||
|
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||||
|
"available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||||
|
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||||
|
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
|
||||||
|
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||||
|
)
|
||||||
|
except EnvironmentError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
|
f"containing a {cls.config_name} file"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
# Load from URL or cache if already cached
|
|
||||||
resolved_config_file = cached_path(
|
|
||||||
config_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
|
||||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
|
||||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
|
||||||
"`use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
|
||||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
|
||||||
"available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
|
||||||
)
|
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
|
||||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
|
||||||
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
|
|
||||||
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
|
||||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a {configuration_file} file"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load config dict
|
# Load config dict
|
||||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
config_dict = cls._dict_from_json_file(config_file)
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
f"It looks like the config file at '{config_file}' is not a valid JSON file."
|
||||||
)
|
)
|
||||||
|
|
||||||
if resolved_config_file == config_file:
|
|
||||||
logger.info(f"loading configuration file {config_file}")
|
|
||||||
else:
|
|
||||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
|
||||||
|
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -208,19 +224,6 @@ class ConfigMixin:
|
||||||
|
|
||||||
return init_dict, unused_kwargs
|
return init_dict, unused_kwargs
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
|
||||||
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
|
||||||
|
|
||||||
model = cls(**init_dict)
|
|
||||||
|
|
||||||
if return_unused_kwargs:
|
|
||||||
return model, unused_kwargs
|
|
||||||
else:
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||||
with open(json_file, "r", encoding="utf-8") as reader:
|
with open(json_file, "r", encoding="utf-8") as reader:
|
||||||
|
@ -233,18 +236,9 @@ class ConfigMixin:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
@property
|
||||||
"""
|
def config(self) -> Dict[str, Any]:
|
||||||
Serializes this instance to a Python dictionary.
|
output = copy.deepcopy(self._dict_to_save)
|
||||||
|
|
||||||
Returns:
|
|
||||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
|
||||||
"""
|
|
||||||
output = copy.deepcopy(self.__dict__)
|
|
||||||
|
|
||||||
# Diffusion version when serializing the model
|
|
||||||
output["diffusers_version"] = __version__
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def to_json_string(self) -> str:
|
def to_json_string(self) -> str:
|
||||||
|
|
|
@ -22,16 +22,8 @@ import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, model_info
|
from huggingface_hub import cached_download
|
||||||
|
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
|
||||||
from transformers.utils import (
|
|
||||||
HF_MODULES_CACHE,
|
|
||||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
|
||||||
cached_path,
|
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
@ -219,7 +211,7 @@ def get_cached_module_file(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_module_file = cached_path(
|
resolved_module_file = cached_download(
|
||||||
module_file_or_url,
|
module_file_or_url,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
|
@ -237,7 +229,7 @@ def get_cached_module_file(
|
||||||
modules_needed = check_imports(resolved_module_file)
|
modules_needed = check_imports(resolved_module_file)
|
||||||
|
|
||||||
# Now we move the module inside our cached dynamic modules.
|
# Now we move the module inside our cached dynamic modules.
|
||||||
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
||||||
create_dynamic_module(full_submodule)
|
create_dynamic_module(full_submodule)
|
||||||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
||||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
||||||
|
|
|
@ -21,18 +21,15 @@ import torch
|
||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
|
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
# CHANGE to diffusers.utils
|
from .utils import (
|
||||||
from transformers.utils import (
|
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
|
DIFFUSERS_CACHE,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
cached_path,
|
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
|
||||||
is_remote_url,
|
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
|
@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
mirror = kwargs.pop("mirror", None)
|
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
|
|
||||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||||
|
|
||||||
if is_offline_mode() and not local_files_only:
|
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
|
||||||
local_files_only = True
|
|
||||||
|
|
||||||
# Load config if we don't provide a configuration
|
# Load config if we don't provide a configuration
|
||||||
config_path = pretrained_model_name_or_path
|
config_path = pretrained_model_name_or_path
|
||||||
model, unused_kwargs = cls.from_config(
|
model, unused_kwargs = cls.from_config(
|
||||||
|
@ -353,79 +345,67 @@ class ModelMixin(torch.nn.Module):
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
|
||||||
archive_file = pretrained_model_name_or_path
|
|
||||||
else:
|
else:
|
||||||
filename = WEIGHTS_NAME
|
try:
|
||||||
|
# Load from URL or cache if already cached
|
||||||
|
model_file = hf_hub_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
filename=WEIGHTS_NAME,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
|
||||||
archive_file = hf_bucket_url(
|
except RepositoryNotFoundError:
|
||||||
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
|
raise EnvironmentError(
|
||||||
)
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||||
|
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||||
|
"login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||||
|
"this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError:
|
||||||
|
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
|
||||||
|
except HTTPError as err:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||||
|
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||||
|
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||||
|
" \nCheckout your internet connection or see how to run the library in"
|
||||||
|
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
|
except EnvironmentError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
|
f"containing a file named {WEIGHTS_NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
# restore default dtype
|
||||||
# Load from URL or cache if already cached
|
state_dict = load_state_dict(model_file)
|
||||||
resolved_archive_file = cached_path(
|
|
||||||
archive_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
|
||||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
|
||||||
"login` and pass `use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
|
||||||
"this model name. Check the model page at "
|
|
||||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
|
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
|
||||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
|
||||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
|
||||||
" \nCheckout your internet connection or see how to run the library in"
|
|
||||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
|
||||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a file named {WEIGHTS_NAME}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if resolved_archive_file == archive_file:
|
|
||||||
logger.info(f"loading weights file {archive_file}")
|
|
||||||
else:
|
|
||||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
|
||||||
|
|
||||||
# restore default dtype
|
|
||||||
state_dict = load_state_dict(resolved_archive_file)
|
|
||||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||||
model,
|
model,
|
||||||
state_dict,
|
state_dict,
|
||||||
resolved_archive_file,
|
model_file,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,8 +20,7 @@ from typing import Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
# CHANGE to diffusers.utils
|
from .utils import logging, DIFFUSERS_CACHE
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||||
|
@ -76,11 +75,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||||
self.save_config(save_directory)
|
self.save_config(save_directory)
|
||||||
|
|
||||||
model_index_dict = self._dict_to_save
|
model_index_dict = self.config
|
||||||
model_index_dict.pop("_class_name")
|
model_index_dict.pop("_class_name")
|
||||||
|
model_index_dict.pop("_diffusers_version")
|
||||||
model_index_dict.pop("_module")
|
model_index_dict.pop("_module")
|
||||||
|
|
||||||
for name, (library_name, class_name) in self._dict_to_save.items():
|
for name, (library_name, class_name) in model_index_dict.items():
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
importable_classes = LOADABLE_CLASSES[library_name]
|
||||||
|
|
||||||
# TODO: Suraj
|
# TODO: Suraj
|
||||||
|
@ -101,14 +101,36 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||||
|
r"""
|
||||||
|
Add docstrings
|
||||||
|
"""
|
||||||
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
|
||||||
# use snapshot download here to get it working from from_pretrained
|
# use snapshot download here to get it working from from_pretrained
|
||||||
if not os.path.isdir(pretrained_model_name_or_path):
|
if not os.path.isdir(pretrained_model_name_or_path):
|
||||||
cached_folder = snapshot_download(pretrained_model_name_or_path)
|
cached_folder = snapshot_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
output_loading_info=output_loading_info,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cached_folder = pretrained_model_name_or_path
|
cached_folder = pretrained_model_name_or_path
|
||||||
|
|
||||||
config_dict = cls.get_config_dict(cached_folder)
|
config_dict = cls.get_config_dict(cached_folder)
|
||||||
|
|
||||||
|
module = config_dict["_module"]
|
||||||
|
class_name_ = config_dict["_class_name"]
|
||||||
module_candidate = config_dict["_module"]
|
module_candidate = config_dict["_module"]
|
||||||
module_candidate_name = module_candidate.replace(".py", "")
|
module_candidate_name = module_candidate.replace(".py", "")
|
||||||
|
|
||||||
|
@ -132,7 +154,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
all_importable_classes.update(LOADABLE_CLASSES[library])
|
all_importable_classes.update(LOADABLE_CLASSES[library])
|
||||||
|
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
|
|
||||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||||
# assumes that it's a subclass of ModelMixin
|
# assumes that it's a subclass of ModelMixin
|
||||||
if library_name == module_candidate_name:
|
if library_name == module_candidate_name:
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
# flake8: noqa
|
||||||
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
import os
|
||||||
|
|
||||||
|
hf_cache_home = os.path.expanduser(
|
||||||
|
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||||
|
)
|
||||||
|
default_cache_path = os.path.join(hf_cache_home, "diffusers")
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG_NAME = "config.json"
|
||||||
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||||
|
DIFFUSERS_CACHE = default_cache_path
|
||||||
|
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||||
|
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryNotFoundError(HTTPError):
|
||||||
|
"""
|
||||||
|
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||||
|
not have access to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class EntryNotFoundError(HTTPError):
|
||||||
|
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
||||||
|
|
||||||
|
|
||||||
|
class RevisionNotFoundError(HTTPError):
|
||||||
|
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
|
@ -0,0 +1,344 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 Optuna, Hugging Face
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Logging utilities."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from logging import CRITICAL # NOQA
|
||||||
|
from logging import DEBUG # NOQA
|
||||||
|
from logging import ERROR # NOQA
|
||||||
|
from logging import FATAL # NOQA
|
||||||
|
from logging import INFO # NOQA
|
||||||
|
from logging import NOTSET # NOQA
|
||||||
|
from logging import WARN # NOQA
|
||||||
|
from logging import WARNING # NOQA
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from tqdm import auto as tqdm_lib
|
||||||
|
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_default_handler: Optional[logging.Handler] = None
|
||||||
|
|
||||||
|
log_levels = {
|
||||||
|
"debug": logging.DEBUG,
|
||||||
|
"info": logging.INFO,
|
||||||
|
"warning": logging.WARNING,
|
||||||
|
"error": logging.ERROR,
|
||||||
|
"critical": logging.CRITICAL,
|
||||||
|
}
|
||||||
|
|
||||||
|
_default_log_level = logging.WARNING
|
||||||
|
|
||||||
|
_tqdm_active = True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_logging_level():
|
||||||
|
"""
|
||||||
|
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
||||||
|
not - fall back to `_default_log_level`
|
||||||
|
"""
|
||||||
|
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
|
||||||
|
if env_level_str:
|
||||||
|
if env_level_str in log_levels:
|
||||||
|
return log_levels[env_level_str]
|
||||||
|
else:
|
||||||
|
logging.getLogger().warning(
|
||||||
|
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
|
||||||
|
f"has to be one of: { ', '.join(log_levels.keys()) }"
|
||||||
|
)
|
||||||
|
return _default_log_level
|
||||||
|
|
||||||
|
|
||||||
|
def _get_library_name() -> str:
|
||||||
|
|
||||||
|
return __name__.split(".")[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_library_root_logger() -> logging.Logger:
|
||||||
|
|
||||||
|
return logging.getLogger(_get_library_name())
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_library_root_logger() -> None:
|
||||||
|
|
||||||
|
global _default_handler
|
||||||
|
|
||||||
|
with _lock:
|
||||||
|
if _default_handler:
|
||||||
|
# This library has already configured the library root logger.
|
||||||
|
return
|
||||||
|
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
||||||
|
_default_handler.flush = sys.stderr.flush
|
||||||
|
|
||||||
|
# Apply our default configuration to the library root logger.
|
||||||
|
library_root_logger = _get_library_root_logger()
|
||||||
|
library_root_logger.addHandler(_default_handler)
|
||||||
|
library_root_logger.setLevel(_get_default_logging_level())
|
||||||
|
library_root_logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_library_root_logger() -> None:
|
||||||
|
|
||||||
|
global _default_handler
|
||||||
|
|
||||||
|
with _lock:
|
||||||
|
if not _default_handler:
|
||||||
|
return
|
||||||
|
|
||||||
|
library_root_logger = _get_library_root_logger()
|
||||||
|
library_root_logger.removeHandler(_default_handler)
|
||||||
|
library_root_logger.setLevel(logging.NOTSET)
|
||||||
|
_default_handler = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_levels_dict():
|
||||||
|
return log_levels
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||||
|
"""
|
||||||
|
Return a logger with the specified name.
|
||||||
|
|
||||||
|
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if name is None:
|
||||||
|
name = _get_library_name()
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
return logging.getLogger(name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_verbosity() -> int:
|
||||||
|
"""
|
||||||
|
Return the current level for the 🤗 Transformers's root logger as an int.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`: The logging level.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
🤗 Transformers has following logging levels:
|
||||||
|
|
||||||
|
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
||||||
|
- 40: `diffusers.logging.ERROR`
|
||||||
|
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
||||||
|
- 20: `diffusers.logging.INFO`
|
||||||
|
- 10: `diffusers.logging.DEBUG`
|
||||||
|
|
||||||
|
</Tip>"""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
return _get_library_root_logger().getEffectiveLevel()
|
||||||
|
|
||||||
|
|
||||||
|
def set_verbosity(verbosity: int) -> None:
|
||||||
|
"""
|
||||||
|
Set the verbosity level for the 🤗 Transformers's root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
verbosity (`int`):
|
||||||
|
Logging level, e.g., one of:
|
||||||
|
|
||||||
|
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
||||||
|
- `diffusers.logging.ERROR`
|
||||||
|
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
||||||
|
- `diffusers.logging.INFO`
|
||||||
|
- `diffusers.logging.DEBUG`
|
||||||
|
"""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
_get_library_root_logger().setLevel(verbosity)
|
||||||
|
|
||||||
|
|
||||||
|
def set_verbosity_info():
|
||||||
|
"""Set the verbosity to the `INFO` level."""
|
||||||
|
return set_verbosity(INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def set_verbosity_warning():
|
||||||
|
"""Set the verbosity to the `WARNING` level."""
|
||||||
|
return set_verbosity(WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
def set_verbosity_debug():
|
||||||
|
"""Set the verbosity to the `DEBUG` level."""
|
||||||
|
return set_verbosity(DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
def set_verbosity_error():
|
||||||
|
"""Set the verbosity to the `ERROR` level."""
|
||||||
|
return set_verbosity(ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
def disable_default_handler() -> None:
|
||||||
|
"""Disable the default handler of the HuggingFace Transformers's root logger."""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
|
||||||
|
assert _default_handler is not None
|
||||||
|
_get_library_root_logger().removeHandler(_default_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_default_handler() -> None:
|
||||||
|
"""Enable the default handler of the HuggingFace Transformers's root logger."""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
|
||||||
|
assert _default_handler is not None
|
||||||
|
_get_library_root_logger().addHandler(_default_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def add_handler(handler: logging.Handler) -> None:
|
||||||
|
"""adds a handler to the HuggingFace Transformers's root logger."""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
|
||||||
|
assert handler is not None
|
||||||
|
_get_library_root_logger().addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_handler(handler: logging.Handler) -> None:
|
||||||
|
"""removes given handler from the HuggingFace Transformers's root logger."""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
|
||||||
|
assert handler is not None and handler not in _get_library_root_logger().handlers
|
||||||
|
_get_library_root_logger().removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def disable_propagation() -> None:
|
||||||
|
"""
|
||||||
|
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
_get_library_root_logger().propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def enable_propagation() -> None:
|
||||||
|
"""
|
||||||
|
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
|
||||||
|
prevent double logging if the root logger has been configured.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
_get_library_root_logger().propagate = True
|
||||||
|
|
||||||
|
|
||||||
|
def enable_explicit_format() -> None:
|
||||||
|
"""
|
||||||
|
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
|
||||||
|
```
|
||||||
|
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
||||||
|
```
|
||||||
|
All handlers currently bound to the root logger are affected by this method.
|
||||||
|
"""
|
||||||
|
handlers = _get_library_root_logger().handlers
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_format() -> None:
|
||||||
|
"""
|
||||||
|
Resets the formatting for HuggingFace Transformers's loggers.
|
||||||
|
|
||||||
|
All handlers currently bound to the root logger are affected by this method.
|
||||||
|
"""
|
||||||
|
handlers = _get_library_root_logger().handlers
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
handler.setFormatter(None)
|
||||||
|
|
||||||
|
|
||||||
|
def warning_advice(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
|
||||||
|
warning will not be printed
|
||||||
|
"""
|
||||||
|
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
|
||||||
|
if no_advisory_warnings:
|
||||||
|
return
|
||||||
|
self.warning(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
logging.Logger.warning_advice = warning_advice
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyTqdm:
|
||||||
|
"""Dummy tqdm which doesn't do anything."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
self._iterator = args[0] if args else None
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._iterator)
|
||||||
|
|
||||||
|
def __getattr__(self, _):
|
||||||
|
"""Return empty function."""
|
||||||
|
|
||||||
|
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
return
|
||||||
|
|
||||||
|
return empty_fn
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type_, value, traceback):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class _tqdm_cls:
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if _tqdm_active:
|
||||||
|
return tqdm_lib.tqdm(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return EmptyTqdm(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_lock(self, *args, **kwargs):
|
||||||
|
self._lock = None
|
||||||
|
if _tqdm_active:
|
||||||
|
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_lock(self):
|
||||||
|
if _tqdm_active:
|
||||||
|
return tqdm_lib.tqdm.get_lock()
|
||||||
|
|
||||||
|
|
||||||
|
tqdm = _tqdm_cls()
|
||||||
|
|
||||||
|
|
||||||
|
def is_progress_bar_enabled() -> bool:
|
||||||
|
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
||||||
|
global _tqdm_active
|
||||||
|
return bool(_tqdm_active)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_progress_bar():
|
||||||
|
"""Enable tqdm progress bar."""
|
||||||
|
global _tqdm_active
|
||||||
|
_tqdm_active = True
|
||||||
|
|
||||||
|
|
||||||
|
def disable_progress_bar():
|
||||||
|
"""Disable tqdm progress bar."""
|
||||||
|
global _tqdm_active
|
||||||
|
_tqdm_active = False
|
|
@ -24,6 +24,7 @@ import torch
|
||||||
|
|
||||||
from diffusers import GaussianDDPMScheduler, UNetModel
|
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from models.vision.ddpm.modeling_ddpm import DDPM
|
from models.vision.ddpm.modeling_ddpm import DDPM
|
||||||
from models.vision.ddim.modeling_ddim import DDIM
|
from models.vision.ddim.modeling_ddim import DDIM
|
||||||
|
|
||||||
|
@ -78,6 +79,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||||
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
|
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigTester(unittest.TestCase):
|
||||||
|
def test_load_not_from_mixin(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ConfigMixin.from_config("dummy_path")
|
||||||
|
|
||||||
|
def test_save_load(self):
|
||||||
|
|
||||||
|
class SampleObject(ConfigMixin):
|
||||||
|
config_name = "config.json"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
a=2,
|
||||||
|
b=5,
|
||||||
|
c=(2, 5),
|
||||||
|
d="for diffusion",
|
||||||
|
e=[1, 3],
|
||||||
|
):
|
||||||
|
self.register(a=a, b=b, c=c, d=d, e=e)
|
||||||
|
|
||||||
|
obj = SampleObject()
|
||||||
|
config = obj.config
|
||||||
|
|
||||||
|
assert config["a"] == 2
|
||||||
|
assert config["b"] == 5
|
||||||
|
assert config["c"] == (2, 5)
|
||||||
|
assert config["d"] == "for diffusion"
|
||||||
|
assert config["e"] == [1, 3]
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
obj.save_config(tmpdirname)
|
||||||
|
new_obj = SampleObject.from_config(tmpdirname)
|
||||||
|
new_config = new_obj.config
|
||||||
|
|
||||||
|
assert config.pop("c") == (2, 5) # instantiated as tuple
|
||||||
|
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
|
||||||
|
assert config == new_config
|
||||||
|
|
||||||
|
|
||||||
class ModelTesterMixin(unittest.TestCase):
|
class ModelTesterMixin(unittest.TestCase):
|
||||||
@property
|
@property
|
||||||
def dummy_input(self):
|
def dummy_input(self):
|
||||||
|
|
Loading…
Reference in New Issue