diff --git a/setup.py b/setup.py index 93a5dd76..fa523db2 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ _deps = [ "filelock", "flake8>=3.8.3", "hf-doc-builder>=0.3.0", - "huggingface-hub", + "huggingface-hub>=0.8.1,<1.0", "importlib_metadata", "isort>=5.5.4", "modelcards==0.1.4", diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 991a078f..646bbf22 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -23,17 +23,11 @@ from collections import OrderedDict from typing import Any, Dict, Tuple, Union from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError from . import __version__ -from .utils import ( - DIFFUSERS_CACHE, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, - EntryNotFoundError, - RepositoryNotFoundError, - RevisionNotFoundError, - logging, -) +from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging logger = logging.get_logger(__name__) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 9874686d..97179425 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -5,10 +5,11 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", "black": "black~=22.0,>=22.3", + "datasets": "datasets", "filelock": "filelock", "flake8": "flake8>=3.8.3", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub", + "huggingface-hub": "huggingface-hub>=0.8.1,<1.0", "importlib_metadata": "importlib_metadata", "isort": "isort>=5.5.4", "modelcards": "modelcards==0.1.4", @@ -16,6 +17,6 @@ deps = { "pytest": "pytest", "regex": "regex!=2019.12.17", "requests": "requests", - "torch": "torch>=1.4", "tensorboard": "tensorboard", + "torch": "torch>=1.4", } diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 44a696ca..3bbc298c 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -21,17 +21,10 @@ import torch from torch import Tensor, device from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError -from .utils import ( - CONFIG_NAME, - DIFFUSERS_CACHE, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, - EntryNotFoundError, - RepositoryNotFoundError, - RevisionNotFoundError, - logging, -) +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging WEIGHTS_NAME = "diffusion_pytorch_model.bin" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index efbd89e2..c7927297 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -1,4 +1,8 @@ -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,13 +15,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib + + import os -from collections import OrderedDict - -import importlib_metadata -from requests.exceptions import HTTPError +from .import_utils import ( + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + USE_JAX, + USE_TF, + USE_TORCH, + DummyObject, + is_flax_available, + is_inflect_available, + is_scipy_available, + is_tf_available, + is_torch_available, + is_transformers_available, + is_unidecode_available, + requires_backends, +) from .logging import get_logger @@ -35,135 +52,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) - - -_transformers_available = importlib.util.find_spec("transformers") is not None -try: - _transformers_version = importlib_metadata.version("transformers") - logger.debug(f"Successfully imported transformers version {_transformers_version}") -except importlib_metadata.PackageNotFoundError: - _transformers_available = False - - -_inflect_available = importlib.util.find_spec("inflect") is not None -try: - _inflect_version = importlib_metadata.version("inflect") - logger.debug(f"Successfully imported inflect version {_inflect_version}") -except importlib_metadata.PackageNotFoundError: - _inflect_available = False - - -_unidecode_available = importlib.util.find_spec("unidecode") is not None -try: - _unidecode_version = importlib_metadata.version("unidecode") - logger.debug(f"Successfully imported unidecode version {_unidecode_version}") -except importlib_metadata.PackageNotFoundError: - _unidecode_available = False - - -_modelcards_available = importlib.util.find_spec("modelcards") is not None -try: - _modelcards_version = importlib_metadata.version("modelcards") - logger.debug(f"Successfully imported modelcards version {_modelcards_version}") -except importlib_metadata.PackageNotFoundError: - _modelcards_available = False - - -_scipy_available = importlib.util.find_spec("scipy") is not None -try: - _scipy_version = importlib_metadata.version("scipy") - logger.debug(f"Successfully imported transformers version {_scipy_version}") -except importlib_metadata.PackageNotFoundError: - _scipy_available = False - - -def is_transformers_available(): - return _transformers_available - - -def is_inflect_available(): - return _inflect_available - - -def is_unidecode_available(): - return _unidecode_available - - -def is_modelcards_available(): - return _modelcards_available - - -def is_scipy_available(): - return _scipy_available - - -class RepositoryNotFoundError(HTTPError): - """ - Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does - not have access to. - """ - - -class EntryNotFoundError(HTTPError): - """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename.""" - - -class RevisionNotFoundError(HTTPError): - """Raised when trying to access a hf.co URL with a valid repository but an invalid revision.""" - - -TRANSFORMERS_IMPORT_ERROR = """ -{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip -install transformers` -""" - - -UNIDECODE_IMPORT_ERROR = """ -{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install -Unidecode` -""" - - -INFLECT_IMPORT_ERROR = """ -{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install -inflect` -""" - - -SCIPY_IMPORT_ERROR = """ -{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install -scipy` -""" - - -BACKENDS_MAPPING = OrderedDict( - [ - ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ] -) - - -def requires_backends(obj, backends): - if not isinstance(backends, (list, tuple)): - backends = [backends] - - name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] - if failed: - raise ImportError("".join(failed)) - - -class DummyObject(type): - """ - Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by - `requires_backend` each time a user tries to access any method of that class. - """ - - def __getattr__(cls, key): - if key.startswith("_"): - return super().__getattr__(cls, key) - requires_backends(cls, cls._backends) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py new file mode 100644 index 00000000..05068b6d --- /dev/null +++ b/src/diffusers/utils/import_utils.py @@ -0,0 +1,255 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" +import importlib.util +import os +import sys +from collections import OrderedDict + +from packaging import version + +from . import logging + + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + +_torch_version = "N/A" +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +_tf_version = "N/A" +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + "tensorflow-aarch64", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.") + _tf_available = False + else: + logger.info(f"TensorFlow version {_tf_version} available.") +else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + + +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + + +_modelcards_available = importlib.util.find_spec("modelcards") is not None +try: + _modelcards_version = importlib_metadata.version("modelcards") + logger.debug(f"Successfully imported modelcards version {_modelcards_version}") +except importlib_metadata.PackageNotFoundError: + _modelcards_available = False + + +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported transformers version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +def is_flax_available(): + return _flax_available + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_modelcards_available(): + return _modelcards_available + + +def is_scipy_available(): + return _scipy_available + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +""" + +# docstyle-ignore +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install +scipy` +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + +# docstyle-ignore +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) diff --git a/utils/check_repo.py b/utils/check_repo.py index 83355f3d..80c63a79 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -22,14 +22,13 @@ from collections import OrderedDict from difflib import get_close_matches from pathlib import Path -from transformers import is_flax_available, is_tf_available, is_torch_available -from transformers.models.auto import get_values -from transformers.utils import ENV_VARS_TRUE_VALUES +from diffusers.models.auto import get_values +from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_repo.py -PATH_TO_TRANSFORMERS = "src/transformers" +PATH_TO_DIFFUSERS = "src/diffusers" PATH_TO_TESTS = "tests" PATH_TO_DOC = "docs/source/en" @@ -200,17 +199,17 @@ MODEL_TYPE_TO_DOC_MAPPING = OrderedDict( # This is to make sure the transformers module imported is the one in the repo. spec = importlib.util.spec_from_file_location( - "transformers", - os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), - submodule_search_locations=[PATH_TO_TRANSFORMERS], + "diffusers", + os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), + submodule_search_locations=[PATH_TO_DIFFUSERS], ) -transformers = spec.loader.load_module() +diffusers = spec.loader.load_module() def check_model_list(): """Check the model list inside the transformers library.""" - # Get the models from the directory structure of `src/transformers/models/` - models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models") + # Get the models from the directory structure of `src/diffusers/models/` + models_dir = os.path.join(PATH_TO_DIFFUSERS, "models") _models = [] for model in os.listdir(models_dir): model_dir = os.path.join(models_dir, model) @@ -218,7 +217,7 @@ def check_model_list(): _models.append(model) # Get the models from the directory structure of `src/transformers/models/` - models = [model for model in dir(transformers.models) if not model.startswith("__")] + models = [model for model in dir(diffusers.models) if not model.startswith("__")] missing_models = sorted(list(set(_models).difference(models))) if missing_models: @@ -256,10 +255,10 @@ def get_model_modules(): "modeling_vision_encoder_decoder", ] modules = [] - for model in dir(transformers.models): + for model in dir(diffusers.models): # There are some magic dunder attributes in the dir, we ignore them if not model.startswith("__"): - model_module = getattr(transformers.models, model) + model_module = getattr(diffusers.models, model) for submodule in dir(model_module): if submodule.startswith("modeling") and submodule not in _ignore_modules: modeling_module = getattr(model_module, submodule) @@ -271,7 +270,7 @@ def get_model_modules(): def get_models(module, include_pretrained=False): """Get the objects in module that are models.""" models = [] - model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin) + model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin) for attr_name in dir(module): if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name): continue @@ -299,7 +298,7 @@ def is_a_private_model(model): def check_models_are_in_init(): """Checks all models defined in the library are in the main init.""" models_not_in_init = [] - dir_transformers = dir(transformers) + dir_transformers = dir(diffusers) for module in get_model_modules(): models_not_in_init += [ model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers @@ -419,17 +418,17 @@ def get_all_auto_configured_models(): """Return the list of all models in at least one auto class.""" result = set() # To avoid duplicates we concatenate all model classes in a set. if is_torch_available(): - for attr_name in dir(transformers.models.auto.modeling_auto): + for attr_name in dir(diffusers.models.auto.modeling_auto): if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"): - result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name))) + result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name))) if is_tf_available(): - for attr_name in dir(transformers.models.auto.modeling_tf_auto): + for attr_name in dir(diffusers.models.auto.modeling_tf_auto): if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"): - result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name))) + result = result | set(get_values(getattr(diffusers.models.auto.modeling_tf_auto, attr_name))) if is_flax_available(): - for attr_name in dir(transformers.models.auto.modeling_flax_auto): + for attr_name in dir(diffusers.models.auto.modeling_flax_auto): if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"): - result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name))) + result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name))) return [cls for cls in result] @@ -636,8 +635,8 @@ def ignore_undocumented(name): ): return True # Submodules are not documented. - if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile( - os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py") + if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile( + os.path.join(PATH_TO_DIFFUSERS, f"{name}.py") ): return True # All load functions are not documented. @@ -660,8 +659,8 @@ def ignore_undocumented(name): def check_all_objects_are_documented(): """Check all models are properly documented.""" documented_objs = find_all_documented_objects() - modules = transformers._modules - objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")] + modules = diffusers._modules + objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")] undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)] if len(undocumented_objs) > 0: raise Exception( @@ -677,7 +676,7 @@ def check_model_type_doc_match(): model_doc_folder = Path(PATH_TO_DOC) / "model_doc" model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")] - model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys()) + model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys()) model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types] errors = [] @@ -723,7 +722,7 @@ def is_rst_docstring(docstring): def check_docstrings_are_in_md(): """Check all docstrings are in md""" files_with_rst = [] - for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"): + for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"): with open(file, "r") as f: code = f.read() docstrings = code.split('"""')