Add `is_torch_available`, `is_flax_available` (#204)

* Add is_<framework>_available, refactor import utils

* deps

* quality
This commit is contained in:
Anton Lozhkov 2022-08-17 16:47:20 +02:00 committed by GitHub
parent ed22b4fd07
commit df90f0ce98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 312 additions and 185 deletions

View File

@ -83,7 +83,7 @@ _deps = [
"filelock",
"flake8>=3.8.3",
"hf-doc-builder>=0.3.0",
"huggingface-hub",
"huggingface-hub>=0.8.1,<1.0",
"importlib_metadata",
"isort>=5.5.4",
"modelcards==0.1.4",

View File

@ -23,17 +23,11 @@ from collections import OrderedDict
from typing import Any, Dict, Tuple, Union
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import __version__
from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
logging,
)
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
logger = logging.get_logger(__name__)

View File

@ -5,10 +5,11 @@ deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"black": "black~=22.0,>=22.3",
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub",
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"modelcards": "modelcards==0.1.4",
@ -16,6 +17,6 @@ deps = {
"pytest": "pytest",
"regex": "regex!=2019.12.17",
"requests": "requests",
"torch": "torch>=1.4",
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
}

View File

@ -21,17 +21,10 @@ import torch
from torch import Tensor, device
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
logging,
)
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
WEIGHTS_NAME = "diffusion_pytorch_model.bin"

View File

@ -1,4 +1,8 @@
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,13 +15,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from collections import OrderedDict
import importlib_metadata
from requests.exceptions import HTTPError
from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
USE_JAX,
USE_TF,
USE_TORCH,
DummyObject,
is_flax_available,
is_inflect_available,
is_scipy_available,
is_tf_available,
is_torch_available,
is_transformers_available,
is_unidecode_available,
requires_backends,
)
from .logging import get_logger
@ -35,135 +52,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
_transformers_version = importlib_metadata.version("transformers")
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
_transformers_available = False
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
_inflect_version = importlib_metadata.version("inflect")
logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
_inflect_available = False
_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
_unidecode_version = importlib_metadata.version("unidecode")
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
_unidecode_available = False
_modelcards_available = importlib.util.find_spec("modelcards") is not None
try:
_modelcards_version = importlib_metadata.version("modelcards")
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
except importlib_metadata.PackageNotFoundError:
_modelcards_available = False
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
_scipy_version = importlib_metadata.version("scipy")
logger.debug(f"Successfully imported transformers version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
_scipy_available = False
def is_transformers_available():
return _transformers_available
def is_inflect_available():
return _inflect_available
def is_unidecode_available():
return _unidecode_available
def is_modelcards_available():
return _modelcards_available
def is_scipy_available():
return _scipy_available
class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class EntryNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""
BACKENDS_MAPPING = OrderedDict(
[
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
]
)
def requires_backends(obj, backends):
if not isinstance(backends, (list, tuple)):
backends = [backends]
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
raise ImportError("".join(failed))
class DummyObject(type):
"""
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
`requires_backend` each time a user tries to access any method of that class.
"""
def __getattr__(cls, key):
if key.startswith("_"):
return super().__getattr__(cls, key)
requires_backends(cls, cls._backends)

View File

@ -0,0 +1,255 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""
import importlib.util
import os
import sys
from collections import OrderedDict
from packaging import version
from . import logging
# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
_torch_version = "N/A"
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False
_tf_version = "N/A"
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
if _tf_available:
candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
"intel-tensorflow-avx512",
"tensorflow-rocm",
"tensorflow-macos",
"tensorflow-aarch64",
)
_tf_version = None
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
for pkg in candidates:
try:
_tf_version = importlib_metadata.version(pkg)
break
except importlib_metadata.PackageNotFoundError:
pass
_tf_available = _tf_version is not None
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
_tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
try:
_jax_version = importlib_metadata.version("jax")
_flax_version = importlib_metadata.version("flax")
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
except importlib_metadata.PackageNotFoundError:
_flax_available = False
else:
_flax_available = False
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
_transformers_version = importlib_metadata.version("transformers")
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
_transformers_available = False
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
_inflect_version = importlib_metadata.version("inflect")
logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
_inflect_available = False
_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
_unidecode_version = importlib_metadata.version("unidecode")
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
_unidecode_available = False
_modelcards_available = importlib.util.find_spec("modelcards") is not None
try:
_modelcards_version = importlib_metadata.version("modelcards")
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
except importlib_metadata.PackageNotFoundError:
_modelcards_available = False
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
_scipy_version = importlib_metadata.version("scipy")
logger.debug(f"Successfully imported transformers version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
_scipy_available = False
def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available
def is_flax_available():
return _flax_available
def is_transformers_available():
return _transformers_available
def is_inflect_available():
return _inflect_available
def is_unidecode_available():
return _unidecode_available
def is_modelcards_available():
return _modelcards_available
def is_scipy_available():
return _scipy_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""
# docstyle-ignore
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""
# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""
# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""
# docstyle-ignore
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""
# docstyle-ignore
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""
BACKENDS_MAPPING = OrderedDict(
[
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
]
)
def requires_backends(obj, backends):
if not isinstance(backends, (list, tuple)):
backends = [backends]
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
raise ImportError("".join(failed))
class DummyObject(type):
"""
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
`requires_backend` each time a user tries to access any method of that class.
"""
def __getattr__(cls, key):
if key.startswith("_"):
return super().__getattr__(cls, key)
requires_backends(cls, cls._backends)

View File

@ -22,14 +22,13 @@ from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.utils import ENV_VARS_TRUE_VALUES
from diffusers.models.auto import get_values
from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_repo.py
PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_DIFFUSERS = "src/diffusers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source/en"
@ -200,17 +199,17 @@ MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
submodule_search_locations=[PATH_TO_TRANSFORMERS],
"diffusers",
os.path.join(PATH_TO_DIFFUSERS, "__init__.py"),
submodule_search_locations=[PATH_TO_DIFFUSERS],
)
transformers = spec.loader.load_module()
diffusers = spec.loader.load_module()
def check_model_list():
"""Check the model list inside the transformers library."""
# Get the models from the directory structure of `src/transformers/models/`
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
# Get the models from the directory structure of `src/diffusers/models/`
models_dir = os.path.join(PATH_TO_DIFFUSERS, "models")
_models = []
for model in os.listdir(models_dir):
model_dir = os.path.join(models_dir, model)
@ -218,7 +217,7 @@ def check_model_list():
_models.append(model)
# Get the models from the directory structure of `src/transformers/models/`
models = [model for model in dir(transformers.models) if not model.startswith("__")]
models = [model for model in dir(diffusers.models) if not model.startswith("__")]
missing_models = sorted(list(set(_models).difference(models)))
if missing_models:
@ -256,10 +255,10 @@ def get_model_modules():
"modeling_vision_encoder_decoder",
]
modules = []
for model in dir(transformers.models):
for model in dir(diffusers.models):
# There are some magic dunder attributes in the dir, we ignore them
if not model.startswith("__"):
model_module = getattr(transformers.models, model)
model_module = getattr(diffusers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
@ -271,7 +270,7 @@ def get_model_modules():
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
models = []
model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin)
for attr_name in dir(module):
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
continue
@ -299,7 +298,7 @@ def is_a_private_model(model):
def check_models_are_in_init():
"""Checks all models defined in the library are in the main init."""
models_not_in_init = []
dir_transformers = dir(transformers)
dir_transformers = dir(diffusers)
for module in get_model_modules():
models_not_in_init += [
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
@ -419,17 +418,17 @@ def get_all_auto_configured_models():
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
if is_torch_available():
for attr_name in dir(transformers.models.auto.modeling_auto):
for attr_name in dir(diffusers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name)))
if is_tf_available():
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
for attr_name in dir(diffusers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
result = result | set(get_values(getattr(diffusers.models.auto.modeling_tf_auto, attr_name)))
if is_flax_available():
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
for attr_name in dir(diffusers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name)))
return [cls for cls in result]
@ -636,8 +635,8 @@ def ignore_undocumented(name):
):
return True
# Submodules are not documented.
if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile(
os.path.join(PATH_TO_DIFFUSERS, f"{name}.py")
):
return True
# All load functions are not documented.
@ -660,8 +659,8 @@ def ignore_undocumented(name):
def check_all_objects_are_documented():
"""Check all models are properly documented."""
documented_objs = find_all_documented_objects()
modules = transformers._modules
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
modules = diffusers._modules
objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")]
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
if len(undocumented_objs) > 0:
raise Exception(
@ -677,7 +676,7 @@ def check_model_type_doc_match():
model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]
model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]
errors = []
@ -723,7 +722,7 @@ def is_rst_docstring(docstring):
def check_docstrings_are_in_md():
"""Check all docstrings are in md"""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"):
with open(file, "r") as f:
code = f.read()
docstrings = code.split('"""')