Add `is_torch_available`, `is_flax_available` (#204)
* Add is_<framework>_available, refactor import utils * deps * quality
This commit is contained in:
parent
ed22b4fd07
commit
df90f0ce98
2
setup.py
2
setup.py
|
@ -83,7 +83,7 @@ _deps = [
|
|||
"filelock",
|
||||
"flake8>=3.8.3",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub",
|
||||
"huggingface-hub>=0.8.1,<1.0",
|
||||
"importlib_metadata",
|
||||
"isort>=5.5.4",
|
||||
"modelcards==0.1.4",
|
||||
|
|
|
@ -23,17 +23,11 @@ from collections import OrderedDict
|
|||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
logging,
|
||||
)
|
||||
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
|
@ -5,10 +5,11 @@ deps = {
|
|||
"Pillow": "Pillow",
|
||||
"accelerate": "accelerate>=0.11.0",
|
||||
"black": "black~=22.0,>=22.3",
|
||||
"datasets": "datasets",
|
||||
"filelock": "filelock",
|
||||
"flake8": "flake8>=3.8.3",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub": "huggingface-hub",
|
||||
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"isort": "isort>=5.5.4",
|
||||
"modelcards": "modelcards==0.1.4",
|
||||
|
@ -16,6 +17,6 @@ deps = {
|
|||
"pytest": "pytest",
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"torch": "torch>=1.4",
|
||||
"tensorboard": "tensorboard",
|
||||
"torch": "torch>=1.4",
|
||||
}
|
||||
|
|
|
@ -21,17 +21,10 @@ import torch
|
|||
from torch import Tensor, device
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
logging,
|
||||
)
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
|
||||
|
||||
|
||||
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -11,13 +15,26 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import importlib_metadata
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
USE_JAX,
|
||||
USE_TF,
|
||||
USE_TORCH,
|
||||
DummyObject,
|
||||
is_flax_available,
|
||||
is_inflect_available,
|
||||
is_scipy_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_unidecode_available,
|
||||
requires_backends,
|
||||
)
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
|
@ -35,135 +52,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
|||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
|
||||
_inflect_available = importlib.util.find_spec("inflect") is not None
|
||||
try:
|
||||
_inflect_version = importlib_metadata.version("inflect")
|
||||
logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_inflect_available = False
|
||||
|
||||
|
||||
_unidecode_available = importlib.util.find_spec("unidecode") is not None
|
||||
try:
|
||||
_unidecode_version = importlib_metadata.version("unidecode")
|
||||
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_unidecode_available = False
|
||||
|
||||
|
||||
_modelcards_available = importlib.util.find_spec("modelcards") is not None
|
||||
try:
|
||||
_modelcards_version = importlib_metadata.version("modelcards")
|
||||
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_modelcards_available = False
|
||||
|
||||
|
||||
_scipy_available = importlib.util.find_spec("scipy") is not None
|
||||
try:
|
||||
_scipy_version = importlib_metadata.version("scipy")
|
||||
logger.debug(f"Successfully imported transformers version {_scipy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scipy_available = False
|
||||
|
||||
|
||||
def is_transformers_available():
|
||||
return _transformers_available
|
||||
|
||||
|
||||
def is_inflect_available():
|
||||
return _inflect_available
|
||||
|
||||
|
||||
def is_unidecode_available():
|
||||
return _unidecode_available
|
||||
|
||||
|
||||
def is_modelcards_available():
|
||||
return _modelcards_available
|
||||
|
||||
|
||||
def is_scipy_available():
|
||||
return _scipy_available
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
not have access to.
|
||||
"""
|
||||
|
||||
|
||||
class EntryNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
||||
|
||||
|
||||
class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
||||
|
||||
|
||||
TRANSFORMERS_IMPORT_ERROR = """
|
||||
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
|
||||
install transformers`
|
||||
"""
|
||||
|
||||
|
||||
UNIDECODE_IMPORT_ERROR = """
|
||||
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
|
||||
Unidecode`
|
||||
"""
|
||||
|
||||
|
||||
INFLECT_IMPORT_ERROR = """
|
||||
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
|
||||
inflect`
|
||||
"""
|
||||
|
||||
|
||||
SCIPY_IMPORT_ERROR = """
|
||||
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
|
||||
scipy`
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
|
||||
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
|
||||
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
|
||||
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def requires_backends(obj, backends):
|
||||
if not isinstance(backends, (list, tuple)):
|
||||
backends = [backends]
|
||||
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
||||
failed = [msg.format(name) for available, msg in checks if not available()]
|
||||
if failed:
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
|
||||
class DummyObject(type):
|
||||
"""
|
||||
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
|
||||
`requires_backend` each time a user tries to access any method of that class.
|
||||
"""
|
||||
|
||||
def __getattr__(cls, key):
|
||||
if key.startswith("_"):
|
||||
return super().__getattr__(cls, key)
|
||||
requires_backends(cls, cls._backends)
|
||||
|
|
|
@ -0,0 +1,255 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
from packaging import version
|
||||
|
||||
from . import logging
|
||||
|
||||
|
||||
# The package importlib_metadata is in a different place, depending on the python version.
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
|
||||
_torch_version = "N/A"
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
if _torch_available:
|
||||
try:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
logger.info(f"PyTorch version {_torch_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
_torch_available = False
|
||||
|
||||
|
||||
_tf_version = "N/A"
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
||||
if _tf_available:
|
||||
candidates = (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
"tensorflow-gpu",
|
||||
"tf-nightly",
|
||||
"tf-nightly-cpu",
|
||||
"tf-nightly-gpu",
|
||||
"intel-tensorflow",
|
||||
"intel-tensorflow-avx512",
|
||||
"tensorflow-rocm",
|
||||
"tensorflow-macos",
|
||||
"tensorflow-aarch64",
|
||||
)
|
||||
_tf_version = None
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for pkg in candidates:
|
||||
try:
|
||||
_tf_version = importlib_metadata.version(pkg)
|
||||
break
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pass
|
||||
_tf_available = _tf_version is not None
|
||||
if _tf_available:
|
||||
if version.parse(_tf_version) < version.parse("2"):
|
||||
logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
|
||||
_tf_available = False
|
||||
else:
|
||||
logger.info(f"TensorFlow version {_tf_version} available.")
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
|
||||
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
if _flax_available:
|
||||
try:
|
||||
_jax_version = importlib_metadata.version("jax")
|
||||
_flax_version = importlib_metadata.version("flax")
|
||||
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_flax_available = False
|
||||
else:
|
||||
_flax_available = False
|
||||
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
|
||||
_inflect_available = importlib.util.find_spec("inflect") is not None
|
||||
try:
|
||||
_inflect_version = importlib_metadata.version("inflect")
|
||||
logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_inflect_available = False
|
||||
|
||||
|
||||
_unidecode_available = importlib.util.find_spec("unidecode") is not None
|
||||
try:
|
||||
_unidecode_version = importlib_metadata.version("unidecode")
|
||||
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_unidecode_available = False
|
||||
|
||||
|
||||
_modelcards_available = importlib.util.find_spec("modelcards") is not None
|
||||
try:
|
||||
_modelcards_version = importlib_metadata.version("modelcards")
|
||||
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_modelcards_available = False
|
||||
|
||||
|
||||
_scipy_available = importlib.util.find_spec("scipy") is not None
|
||||
try:
|
||||
_scipy_version = importlib_metadata.version("scipy")
|
||||
logger.debug(f"Successfully imported transformers version {_scipy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scipy_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
def is_flax_available():
|
||||
return _flax_available
|
||||
|
||||
|
||||
def is_transformers_available():
|
||||
return _transformers_available
|
||||
|
||||
|
||||
def is_inflect_available():
|
||||
return _inflect_available
|
||||
|
||||
|
||||
def is_unidecode_available():
|
||||
return _unidecode_available
|
||||
|
||||
|
||||
def is_modelcards_available():
|
||||
return _modelcards_available
|
||||
|
||||
|
||||
def is_scipy_available():
|
||||
return _scipy_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
installation page: https://github.com/google/flax and follow the ones that match your environment.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
INFLECT_IMPORT_ERROR = """
|
||||
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
|
||||
inflect`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
PYTORCH_IMPORT_ERROR = """
|
||||
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
|
||||
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
SCIPY_IMPORT_ERROR = """
|
||||
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
|
||||
scipy`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
TENSORFLOW_IMPORT_ERROR = """
|
||||
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
|
||||
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
TRANSFORMERS_IMPORT_ERROR = """
|
||||
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
|
||||
install transformers`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
UNIDECODE_IMPORT_ERROR = """
|
||||
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
|
||||
Unidecode`
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
|
||||
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
|
||||
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||
("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
|
||||
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
|
||||
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def requires_backends(obj, backends):
|
||||
if not isinstance(backends, (list, tuple)):
|
||||
backends = [backends]
|
||||
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
||||
failed = [msg.format(name) for available, msg in checks if not available()]
|
||||
if failed:
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
|
||||
class DummyObject(type):
|
||||
"""
|
||||
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
|
||||
`requires_backend` each time a user tries to access any method of that class.
|
||||
"""
|
||||
|
||||
def __getattr__(cls, key):
|
||||
if key.startswith("_"):
|
||||
return super().__getattr__(cls, key)
|
||||
requires_backends(cls, cls._backends)
|
|
@ -22,14 +22,13 @@ from collections import OrderedDict
|
|||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.utils import ENV_VARS_TRUE_VALUES
|
||||
from diffusers.models.auto import get_values
|
||||
from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_repo.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
PATH_TO_DIFFUSERS = "src/diffusers"
|
||||
PATH_TO_TESTS = "tests"
|
||||
PATH_TO_DOC = "docs/source/en"
|
||||
|
||||
|
@ -200,17 +199,17 @@ MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
|
|||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
|
||||
submodule_search_locations=[PATH_TO_TRANSFORMERS],
|
||||
"diffusers",
|
||||
os.path.join(PATH_TO_DIFFUSERS, "__init__.py"),
|
||||
submodule_search_locations=[PATH_TO_DIFFUSERS],
|
||||
)
|
||||
transformers = spec.loader.load_module()
|
||||
diffusers = spec.loader.load_module()
|
||||
|
||||
|
||||
def check_model_list():
|
||||
"""Check the model list inside the transformers library."""
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
|
||||
# Get the models from the directory structure of `src/diffusers/models/`
|
||||
models_dir = os.path.join(PATH_TO_DIFFUSERS, "models")
|
||||
_models = []
|
||||
for model in os.listdir(models_dir):
|
||||
model_dir = os.path.join(models_dir, model)
|
||||
|
@ -218,7 +217,7 @@ def check_model_list():
|
|||
_models.append(model)
|
||||
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
models = [model for model in dir(transformers.models) if not model.startswith("__")]
|
||||
models = [model for model in dir(diffusers.models) if not model.startswith("__")]
|
||||
|
||||
missing_models = sorted(list(set(_models).difference(models)))
|
||||
if missing_models:
|
||||
|
@ -256,10 +255,10 @@ def get_model_modules():
|
|||
"modeling_vision_encoder_decoder",
|
||||
]
|
||||
modules = []
|
||||
for model in dir(transformers.models):
|
||||
for model in dir(diffusers.models):
|
||||
# There are some magic dunder attributes in the dir, we ignore them
|
||||
if not model.startswith("__"):
|
||||
model_module = getattr(transformers.models, model)
|
||||
model_module = getattr(diffusers.models, model)
|
||||
for submodule in dir(model_module):
|
||||
if submodule.startswith("modeling") and submodule not in _ignore_modules:
|
||||
modeling_module = getattr(model_module, submodule)
|
||||
|
@ -271,7 +270,7 @@ def get_model_modules():
|
|||
def get_models(module, include_pretrained=False):
|
||||
"""Get the objects in module that are models."""
|
||||
models = []
|
||||
model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
|
||||
model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin)
|
||||
for attr_name in dir(module):
|
||||
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
|
||||
continue
|
||||
|
@ -299,7 +298,7 @@ def is_a_private_model(model):
|
|||
def check_models_are_in_init():
|
||||
"""Checks all models defined in the library are in the main init."""
|
||||
models_not_in_init = []
|
||||
dir_transformers = dir(transformers)
|
||||
dir_transformers = dir(diffusers)
|
||||
for module in get_model_modules():
|
||||
models_not_in_init += [
|
||||
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
|
||||
|
@ -419,17 +418,17 @@ def get_all_auto_configured_models():
|
|||
"""Return the list of all models in at least one auto class."""
|
||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||
if is_torch_available():
|
||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||
for attr_name in dir(diffusers.models.auto.modeling_auto):
|
||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
||||
result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name)))
|
||||
if is_tf_available():
|
||||
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
||||
for attr_name in dir(diffusers.models.auto.modeling_tf_auto):
|
||||
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
||||
result = result | set(get_values(getattr(diffusers.models.auto.modeling_tf_auto, attr_name)))
|
||||
if is_flax_available():
|
||||
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
||||
for attr_name in dir(diffusers.models.auto.modeling_flax_auto):
|
||||
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
||||
result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name)))
|
||||
return [cls for cls in result]
|
||||
|
||||
|
||||
|
@ -636,8 +635,8 @@ def ignore_undocumented(name):
|
|||
):
|
||||
return True
|
||||
# Submodules are not documented.
|
||||
if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
|
||||
os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
|
||||
if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile(
|
||||
os.path.join(PATH_TO_DIFFUSERS, f"{name}.py")
|
||||
):
|
||||
return True
|
||||
# All load functions are not documented.
|
||||
|
@ -660,8 +659,8 @@ def ignore_undocumented(name):
|
|||
def check_all_objects_are_documented():
|
||||
"""Check all models are properly documented."""
|
||||
documented_objs = find_all_documented_objects()
|
||||
modules = transformers._modules
|
||||
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
|
||||
modules = diffusers._modules
|
||||
objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")]
|
||||
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
|
||||
if len(undocumented_objs) > 0:
|
||||
raise Exception(
|
||||
|
@ -677,7 +676,7 @@ def check_model_type_doc_match():
|
|||
model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
|
||||
model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]
|
||||
|
||||
model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
|
||||
model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
|
||||
model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]
|
||||
|
||||
errors = []
|
||||
|
@ -723,7 +722,7 @@ def is_rst_docstring(docstring):
|
|||
def check_docstrings_are_in_md():
|
||||
"""Check all docstrings are in md"""
|
||||
files_with_rst = []
|
||||
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
|
||||
for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"):
|
||||
with open(file, "r") as f:
|
||||
code = f.read()
|
||||
docstrings = code.split('"""')
|
||||
|
|
Loading…
Reference in New Issue