upload some cleaning tools
This commit is contained in:
parent d849816659
commit 95f4256fc9
@@ -0,0 +1,166 @@
# Initially taken from Github's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data

# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep

# data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

#ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock

# DS_Store (MacOS)
.DS_Store
@@ -0,0 +1,109 @@
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples

# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := models tests src utils

modified_only_fixup:
	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
	@if test -n "$(modified_py_files)"; then \
		echo "Checking/fixing $(modified_py_files)"; \
		black --preview $(modified_py_files); \
		isort $(modified_py_files); \
		flake8 $(modified_py_files); \
	else \
		echo "No library .py files were modified"; \
	fi

# Update src/diffusers/dependency_versions_table.py

deps_table_update:
	@python setup.py deps_table_update

deps_table_check_updated:
	@md5sum src/diffusers/dependency_versions_table.py > md5sum.saved
	@python setup.py deps_table_update
	@md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
	@rm md5sum.saved

# autogenerating code

autogenerate_code: deps_table_update

# Check that the repo is in a good state

repo-consistency:
	python utils/check_copies.py
	python utils/check_table.py
	python utils/check_dummies.py
	python utils/check_repo.py
	python utils/check_inits.py
	python utils/check_config_docstrings.py
	python utils/tests_fetcher.py --sanity_check

# this target runs checks on all files

quality:
	black --check --preview $(check_dirs)
	isort --check-only $(check_dirs)
	python utils/custom_init_isort.py --check_only
	python utils/sort_auto_mappings.py --check_only
	flake8 $(check_dirs)
	doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source

# Format source code automatically and check if there are any problems left that need manual fixing

extra_style_checks:
	python utils/custom_init_isort.py
	python utils/sort_auto_mappings.py
	doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source

# this target runs checks on all files and potentially modifies some of them

style:
	black --preview $(check_dirs)
	isort $(check_dirs)
	${MAKE} autogenerate_code
	${MAKE} extra_style_checks

# Super fast fix and check target that only works on relevant modified files since the branch was made

fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency

# Make marked copies of snippets of code conform to the original

fix-copies:
	python utils/check_copies.py --fix_and_overwrite
	python utils/check_table.py --fix_and_overwrite
	python utils/check_dummies.py --fix_and_overwrite

# Run tests for the library

test:
	python -m pytest -n auto --dist=loadfile -s -v ./tests/

# Run tests for examples

test-examples:
	python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/

# Run tests for SageMaker DLC release

test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker]
	TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker


# Release stuff

pre-release:
	python utils/release.py

pre-patch:
	python utils/release.py --patch

post-release:
	python utils/release.py --post_release

post-patch:
	python utils/release.py --post_release --patch
@@ -0,0 +1 @@
ce075df80e7ba2391d63d026be165c15 src/diffusers/dependency_versions_table.py
setup.py
@@ -12,8 +12,164 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py

To create the package for pypi.

1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
   documentation.

   If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
   for the post-release and run `make fix-copies` on the main branch as well.

2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise ask @philschmid.

3. Unpin specific versions from setup.py that use a git install.

4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
   message: "Release: <VERSION>" and push.

5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)

6. Add a tag in git to mark the release: "git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' "
   Push the tag to git: git push --tags origin v<RELEASE>-release

7. Build both the sources and the wheel. Do not change anything in setup.py between
   creating the wheel and the source distribution (obviously).

   For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
   (this will build a wheel for the python version you use to build it).

   For the sources, run: "python setup.py sdist"
   You should now have a /dist directory with both .whl and .tar.gz source versions.

8. Check that everything looks correct by uploading the package to the pypi test server:

   twine upload dist/* -r pypitest
   (pypi suggests using twine as other methods upload files via plaintext.)
   You may have to specify the repository URL; in that case, use the following command:
   twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/

   Check that you can install it in a virtualenv by running:
   pip install -i https://testpypi.python.org/pypi transformers

   Check you can run the following commands:
   python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
   python -c "from transformers import *"

9. Upload the final version to actual pypi:
   twine upload dist/* -r pypi

10. Copy the release notes from RELEASE.md to the tag in GitHub once everything is looking hunky-dory.

11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
    you need to go back to main before executing this.
"""

import re
from distutils.core import Command

from setuptools import find_packages, setup


# IMPORTANT:
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
    "Pillow",
    "accelerate>=0.9.0",
    "black~=22.0,>=22.3",
    "codecarbon==1.2.0",
    "dataclasses",
    "datasets",
    "GitPython<3.1.19",
    "hf-doc-builder>=0.3.0",
    "huggingface-hub>=0.1.0,<1.0",
    "importlib_metadata",
    "isort>=5.5.4",
    "numpy>=1.17",
    "pytest",
    "pytest-timeout",
    "pytest-xdist",
    "python>=3.7.0",
    "regex!=2019.12.17",
    "requests",
    "sagemaker>=2.31.0",
    "tokenizers>=0.11.1,!=0.11.3,<0.13",
    "torch>=1.4",
    "torchaudio",
    "tqdm>=4.27",
    "unidic>=1.0.2",
    "unidic_lite>=1.0.7",
    "uvicorn",
]

# this is a lookup table with items like:
#
# tokenizers: "tokenizers==0.9.4"
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}

# since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
#
# python -c 'import sys; from diffusers.dependency_versions_table import deps; \
# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets
#
# Just pass the desired package names to that script as it's shown with 2 packages above.
#
# If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above
#
# You can then feed this for example to `pip`:
#
# pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \
# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets)
#


def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]


class DepsTableUpdateCommand(Command):
    """
    A custom distutils command that updates the dependency table.
    usage: python setup.py deps_table_update
    """

    description = "build runtime dependency table"
    user_options = [
        # format: (long option, short option, description).
        ("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"),
    ]

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        entries = "\n".join([f'    "{k}": "{v}",' for k, v in deps.items()])
        content = [
            "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
            "# 1. modify the `_deps` dict in setup.py",
            "# 2. run `make deps_table_update`",
            "deps = {",
            entries,
            "}",
            "",
        ]
        target = "src/diffusers/dependency_versions_table.py"
        print(f"updating {target}")
        with open(target, "w", encoding="utf-8", newline="\n") as f:
            f.write("\n".join(content))


extras = {}
extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
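The `deps` dictionary above is built from `_deps` with a small regex that splits each requirement into its bare name and its full specifier. A minimal sketch of what it produces, using a few entries from this file:

```python
import re

_deps = ["accelerate>=0.9.0", "Pillow", "tokenizers>=0.11.1,!=0.11.3,<0.13"]

# Each entry yields a (full spec, bare package name) pair; map name -> spec.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}

print(deps["accelerate"])   # accelerate>=0.9.0
print(deps["Pillow"])       # Pillow (unversioned entries map to themselves)
print(deps["tokenizers"])   # tokenizers>=0.11.1,!=0.11.3,<0.13
```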
@@ -61,6 +217,7 @@ setup(
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    cmdclass={"deps_table_update": DepsTableUpdateCommand},
)

# Release checklist
@@ -0,0 +1,47 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core


# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed

        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


def dep_version_check(pkg, hint=None):
    require_version(deps[pkg], hint)
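As a rough illustration of the helper defined above (the module path is inferred from the relative imports in this file, and `require_version` is assumed to raise when the installed package does not satisfy the pinned specifier):

```python
# A hypothetical usage sketch, not part of the commit.
from diffusers.dependency_versions_check import dep_version_check

# Raises if the installed numpy does not satisfy the "numpy>=1.17" entry of the table.
dep_version_check("numpy")

# Packages that are not listed in dependency_versions_table.py raise a KeyError.
# dep_version_check("not-a-dependency")
```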
@@ -0,0 +1,31 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.9.0",
    "black": "black~=22.0,>=22.3",
    "codecarbon": "codecarbon==1.2.0",
    "dataclasses": "dataclasses",
    "datasets": "datasets",
    "GitPython": "GitPython<3.1.19",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
    "importlib_metadata": "importlib_metadata",
    "isort": "isort>=5.5.4",
    "numpy": "numpy>=1.17",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.7.0",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "sagemaker": "sagemaker>=2.31.0",
    "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
    "torch": "torch>=1.4",
    "torchaudio": "torchaudio",
    "tqdm": "tqdm>=4.27",
    "unidic": "unidic>=1.0.2",
    "unidic_lite": "unidic_lite>=1.0.7",
    "uvicorn": "uvicorn",
}
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


class UNetModel:
    def __init__(self, config):
        self.config = config
        print("I can diffuse!")
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
import re


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_config_docstrings.py
PATH_TO_TRANSFORMERS = "src/transformers"


# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers = spec.loader.load_module()

CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING

# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
# For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)`
_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")


CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "CLIPConfig",
    "DecisionTransformerConfig",
    "EncoderDecoderConfig",
    "RagConfig",
    "SpeechEncoderDecoderConfig",
    "VisionEncoderDecoderConfig",
    "VisionTextDualEncoderConfig",
}


def check_config_docstrings_have_checkpoints():
    configs_without_checkpoint = []

    for config_class in list(CONFIG_MAPPING.values()):
        checkpoint_found = False

        # source code of `config_class`
        config_source = inspect.getsource(config_class)
        checkpoints = _re_checkpoint.findall(config_source)

        for checkpoint in checkpoints:
            # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
            # For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')`
            ckpt_name, ckpt_link = checkpoint

            # verify the checkpoint name corresponds to the checkpoint link
            ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
            if ckpt_link == ckpt_link_from_name:
                checkpoint_found = True
                break

        name = config_class.__name__
        if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
            configs_without_checkpoint.append(name)

    if len(configs_without_checkpoint) > 0:
        message = "\n".join(sorted(configs_without_checkpoint))
        raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}")


if __name__ == "__main__":
    check_config_docstrings_have_checkpoints()
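A small sketch of what the `_re_checkpoint` pattern extracts from a config docstring, reusing the example given in the comment above:

```python
import re

_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")

docstring = "... similar configuration to that of the [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture."
print(_re_checkpoint.findall(docstring))
# [('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')]
```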
@@ -0,0 +1,458 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import os
import re

import black
from doc_builder.style_doc import style_docstrings_in_code


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."

# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to date with)
FULL_COPIES = {
    "examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
    "examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
}


LOCALIZED_READMES = {
    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
    "README.md": {
        "start_prompt": "🤗 Transformers currently provides the following architectures",
        "end_prompt": "1. Want to contribute a new model?",
        "format_model_list": (
            "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
            " {paper_authors}.{supplements}"
        ),
    },
    "README_zh-hans.md": {
        "start_prompt": "🤗 Transformers 目前支持如下的架构",
        "end_prompt": "1. 想要贡献新的模型?",
        "format_model_list": (
            "**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link} 由 {paper_authors}"
            " 发布。{supplements}"
        ),
    },
    "README_zh-hant.md": {
        "start_prompt": "🤗 Transformers 目前支援以下的架構",
        "end_prompt": "1. 想要貢獻新的模型?",
        "format_model_list": (
            "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
            " {paper_authors}.{supplements}"
        ),
    },
    "README_ko.md": {
        "start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다",
        "end_prompt": "1. 새로운 모델을 올리고 싶나요?",
        "format_model_list": (
            "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
            " {paper_authors}.{supplements}"
        ),
    },
}


def _should_continue(line, indent):
    return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None


def find_code_in_transformers(object_name):
    """Find and return the source code of `object_name`."""
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(
            f"`object_name` should begin with the name of a module of transformers but got {object_name}."
        )

    with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Now let's find the class / func in the code!
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        while (
            line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
        ):
            line_index += 1
        indent += "    "
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Clean up empty lines at the end (if any).
    while len(lines[line_index - 1]) <= 1:
        line_index -= 1

    code_lines = lines[start_index:line_index]
    return "".join(code_lines)


_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")


def get_indent(code):
    lines = code.split("\n")
    idx = 0
    while idx < len(lines) and len(lines[idx]) == 0:
        idx += 1
    if idx < len(lines):
        return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
    return ""


def blackify(code):
    """
    Applies the black part of our `make style` command to `code`.
    """
    has_indent = len(get_indent(code)) > 0
    if has_indent:
        code = f"class Bla:\n{code}"
    mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True)
    result = black.format_str(code, mode=mode)
    result, _ = style_docstrings_in_code(result)
    return result[len("class Bla:\n") :] if has_indent else result


def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.

    Return the differences or overwrites the content depending on `overwrite`.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    diffs = []
    line_index = 0
    # Not a for loop cause `lines` is going to change (if `overwrite=True`).
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_transformers(object_name)
        theoretical_indent = get_indent(theoretical_code)

        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        indent = theoretical_indent
        line_index = start_index

        # Loop to check the observed code, stop when indentation diminishes or if we see an End copy comment.
        should_continue = True
        while line_index < len(lines) and should_continue:
            line_index += 1
            if line_index >= len(lines):
                break
            line = lines[line_index]
            should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
        # Clean up empty lines at the end (if any).
        while len(lines[line_index - 1]) <= 1:
            line_index -= 1

        observed_code_lines = lines[start_index:line_index]
        observed_code = "".join(observed_code_lines)

        # Before comparing, use the `replace_pattern` on the original code.
        if len(replace_pattern) > 0:
            patterns = replace_pattern.replace("with", "").split(",")
            patterns = [_re_replace_pattern.search(p) for p in patterns]
            for pattern in patterns:
                if pattern is None:
                    continue
                obj1, obj2, option = pattern.groups()
                theoretical_code = re.sub(obj1, obj2, theoretical_code)
                if option.strip() == "all-casing":
                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

            # Blackify after replacement. To be able to do that, we need the header (class or function definition)
            # from the previous line
            theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
            theoretical_code = theoretical_code[len(lines[start_index - 1]) :]

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs


def check_copies(overwrite: bool = False):
    all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
    diffs = []
    for filename in all_files:
        new_diffs = is_copy_consistent(filename, overwrite)
        diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found the following copy inconsistencies:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )
    check_model_list_copy(overwrite=overwrite)


def check_full_copies(overwrite: bool = False):
    diffs = []
    for target, source in FULL_COPIES.items():
        with open(source, "r", encoding="utf-8") as f:
            source_code = f.read()
        with open(target, "r", encoding="utf-8") as f:
            target_code = f.read()
        if source_code != target_code:
            if overwrite:
                with open(target, "w", encoding="utf-8") as f:
                    print(f"Replacing the content of {target} by the one of {source}.")
                    f.write(source_code)
            else:
                diffs.append(f"- {target}: copy does not match {source}.")

    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found the following copy inconsistencies:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )


def get_model_list(filename, start_prompt, end_prompt):
    """Extracts the model list from the README."""
    with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    # Find the start of the list.
    start_index = 0
    while not lines[start_index].startswith(start_prompt):
        start_index += 1
    start_index += 1

    result = []
    current_line = ""
    end_index = start_index

    while not lines[end_index].startswith(end_prompt):
        if lines[end_index].startswith("1."):
            if len(current_line) > 1:
                result.append(current_line)
            current_line = lines[end_index]
        elif len(lines[end_index]) > 1:
            current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}"
        end_index += 1
    if len(current_line) > 1:
        result.append(current_line)

    return "".join(result)


def convert_to_localized_md(model_list, localized_model_list, format_str):
    """Convert `model_list` to each localized README."""

    def _rep(match):
        title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
        return format_str.format(
            title=title,
            model_link=model_link,
            paper_affiliations=paper_affiliations,
            paper_title_link=paper_title_link,
            paper_authors=paper_authors,
            supplements=" " + supplements.strip() if len(supplements) != 0 else "",
        )

    # This regex captures metadata from an English model description, including model title, model link,
    # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
    _re_capture_meta = re.compile(
        r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
    )
    # This regex is used to synchronize links.
    _re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")

    if len(localized_model_list) == 0:
        localized_model_index = {}
    else:
        try:
            localized_model_index = {
                re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line
                for line in localized_model_list.strip().split("\n")
            }
        except AttributeError:
            raise AttributeError("A model name in localized READMEs cannot be recognized.")

    model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]

    # We exclude keys in localized README not in the main one.
    readmes_match = not any([k not in model_keys for k in localized_model_index])
    localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}

    for model in model_list.strip().split("\n"):
        title, model_link = _re_capture_title_link.search(model).groups()
        if title not in localized_model_index:
            readmes_match = False
            # Add an anchor white space behind a model description string for regex.
            # If metadata cannot be captured, the English version will be directly copied.
            localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
        else:
            # Synchronize link
            localized_model_index[title] = _re_capture_title_link.sub(
                f"**[{title}]({model_link})**", localized_model_index[title], count=1
            )

    sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())

    return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"


def convert_readme_to_index(model_list):
    model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
    return model_list.replace("https://huggingface.co/docs/transformers/", "")


def _find_text_in_file(filename, start_prompt, end_prompt):
    """
    Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
    lines.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    # Find the start prompt.
    start_index = 0
    while not lines[start_index].startswith(start_prompt):
        start_index += 1
    start_index += 1

    end_index = start_index
    while not lines[end_index].startswith(end_prompt):
        end_index += 1
    end_index -= 1

    while len(lines[start_index]) <= 1:
        start_index += 1
    while len(lines[end_index]) <= 1:
        end_index -= 1
    end_index += 1
    return "".join(lines[start_index:end_index]), start_index, end_index, lines


def check_model_list_copy(overwrite=False, max_per_line=119):
    """Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
    # Fix potential doc links in the README
    with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
        readme = f.read()
    new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
    new_readme = new_readme.replace(
        "https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
    )
    if new_readme != readme:
        if overwrite:
            with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f:
                f.write(new_readme)
        else:
            raise ValueError(
                "The main README contains wrong links to the documentation of Transformers. Run `make fix-copies` to "
                "automatically fix them."
            )

    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
    index_list, start_index, end_index, lines = _find_text_in_file(
        filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
        start_prompt="<!--This list is updated automatically from the README",
        end_prompt="### Supported frameworks",
    )
    md_list = get_model_list(
        filename="README.md",
        start_prompt=LOCALIZED_READMES["README.md"]["start_prompt"],
        end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"],
    )

    converted_md_lists = []
    for filename, value in LOCALIZED_READMES.items():
        _start_prompt = value["start_prompt"]
        _end_prompt = value["end_prompt"]
        _format_model_list = value["format_model_list"]

        localized_md_list = get_model_list(filename, _start_prompt, _end_prompt)
        readmes_match, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)

        converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))

    converted_md_list = convert_readme_to_index(md_list)
    if converted_md_list != index_list:
        if overwrite:
            with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
                f.writelines(lines[:start_index] + [converted_md_list] + lines[end_index:])
        else:
            raise ValueError(
                "The model list in the README changed and the list in `index.mdx` has not been updated. Run "
                "`make fix-copies` to fix this."
            )

    for converted_md_list in converted_md_lists:
        filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list

        if filename == "README.md":
            continue
        if overwrite:
            _, start_index, end_index, lines = _find_text_in_file(
                filename=os.path.join(REPO_PATH, filename), start_prompt=_start_prompt, end_prompt=_end_prompt
            )
            with open(os.path.join(REPO_PATH, filename), "w", encoding="utf-8", newline="\n") as f:
                f.writelines(lines[:start_index] + [converted_md] + lines[end_index:])
        elif not readmes_match:
            raise ValueError(
                f"The model list in the README changed and the list in `{filename}` has not been updated. Run "
                "`make fix-copies` to fix this."
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_copies(args.fix_and_overwrite)
    check_full_copies(args.fix_and_overwrite)
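A minimal sketch of the two regexes that drive this check, applied to a hypothetical `# Copied from` marker (the class names below are only illustrative):

```python
import re

_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")

line = "    # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Roberta"
indent, object_name, replace_pattern = _re_copy_warning.search(line).groups()
print(object_name)      # models.bert.modeling_bert.BertSelfOutput
print(replace_pattern)  # with Bert->Roberta

# The replace pattern is then split and applied to the original code before comparing.
for p in replace_pattern.replace("with", "").split(","):
    match = _re_replace_pattern.search(p)
    if match is not None:
        obj1, obj2, option = match.groups()
        print(obj1, "->", obj2)  # Bert -> Roberta
```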
@@ -0,0 +1,168 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")


DUMMY_CONSTANT = """
{0} = None
"""

DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
    _backends = {1}

    def __init__(self, *args, **kwargs):
        requires_backends(self, {1})
"""


DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_backends({0}, {1})
"""


def find_backend(line):
    """Find one (or multiple) backends in a code line of the init."""
    if _re_test_backend.search(line) is None:
        return None
    backends = [b[0] for b in _re_backend.findall(line)]
    backends.sort()
    return "_and_".join(backends)


def read_init():
    """Read the init and extract PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Get to the point we do the actual imports for type checking
    line_index = 0
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        if backend is not None:
            while not lines[line_index].startswith("    else:"):
                line_index += 1
            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_single_line_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            backend_specific_objects[backend] = objects
        else:
            line_index += 1

    return backend_specific_objects


def create_dummy_object(name, backend_name):
    """Create the code for the dummy object corresponding to `name`."""
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    elif name.islower():
        return DUMMY_FUNCTION.format(name, backend_name)
    else:
        return DUMMY_CLASS.format(name, backend_name)


def create_dummy_files():
    """Create the content of the dummy files."""
    backend_specific_objects = read_init()
    # For special correspondence backend to module name as used in the function requires_modulename
    dummy_files = {}

    for backend, objects in backend_specific_objects.items():
        backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
        dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        dummy_file += "# flake8: noqa\n"
        dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
        dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
        dummy_files[backend] = dummy_file

    return dummy_files


def check_dummies(overwrite=False):
    """Check if the dummy files are up to date and maybe `overwrite` with the right content."""
    dummy_files = create_dummy_files()
    # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}

    # Locate actual dummy modules and read their content.
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    dummy_file_paths = {
        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
        for backend in dummy_files.keys()
    }

    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                actual_dummies[backend] = f.read()
        else:
            actual_dummies[backend] = ""

    for backend in dummy_files.keys():
        if dummy_files[backend] != actual_dummies[backend]:
            if overwrite:
                print(
                    f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
                    "__init__ has new objects."
                )
                with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                    f.write(dummy_files[backend])
            else:
                raise ValueError(
                    "The main __init__ has objects that are not present in "
                    f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
                    "to fix this."
                )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_dummies(args.fix_and_overwrite)
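For illustration, this is roughly what `create_dummy_object` emits for a class name and for an upper-case constant; the object names are only examples, and the import assumes the script is run with `utils/` on `sys.path`:

```python
from check_dummies import DUMMY_CLASS, DUMMY_CONSTANT  # assuming utils/ is on sys.path

backend_name = '["torch"]'

print(DUMMY_CLASS.format("UNetModel", backend_name))
# class UNetModel(metaclass=DummyObject):
#     _backends = ["torch"]
#
#     def __init__(self, *args, **kwargs):
#         requires_backends(self, ["torch"])

print(DUMMY_CONSTANT.format("ONNX_WEIGHTS_NAME"))
# ONNX_WEIGHTS_NAME = None
```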
@ -0,0 +1,299 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
# Catches a one-line _import_struct = {xxx}
|
||||
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
|
||||
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
|
||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||
# Catches a line if not is_foo_available
|
||||
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
|
||||
# Catches a line _import_struct["bla"].append("foo")
|
||||
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
||||
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
|
||||
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
|
||||
# Catches a line with an object between quotes and a comma: "MyModel",
|
||||
_re_quote_object = re.compile('^\s+"([^"]+)",')
|
||||
# Catches a line with objects between brackets only: ["foo", "bar"],
|
||||
_re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
|
||||
# Catches a line with from foo import bar, bla, boo
|
||||
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
# Catches a line with try:
|
||||
_re_try = re.compile(r"^\s*try:")
|
||||
# Catches a line with else:
|
||||
_re_else = re.compile(r"^\s*else:")
|
||||
|
||||
|
||||
def find_backend(line):
|
||||
"""Find one (or multiple) backend in a code line of the init."""
|
||||
if _re_test_backend.search(line) is None:
|
||||
return None
|
||||
backends = [b[0] for b in _re_backend.findall(line)]
|
||||
backends.sort()
|
||||
return "_and_".join(backends)
|
||||
|
||||
|
||||
def parse_init(init_file):
|
||||
"""
|
||||
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
|
||||
defined
|
||||
"""
|
||||
with open(init_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
line_index = 0
|
||||
while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
|
||||
line_index += 1
|
||||
|
||||
# If this is a traditional init, just return.
|
||||
if line_index >= len(lines):
|
||||
return None
|
||||
|
||||
# First grab the objects without a specific backend in _import_structure
|
||||
objects = []
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
|
||||
line = lines[line_index]
|
||||
# If we have everything on a single line, let's deal with it.
|
||||
if _re_one_line_import_struct.search(line):
|
||||
content = _re_one_line_import_struct.search(line).groups()[0]
|
||||
imports = re.findall("\[([^\]]+)\]", content)
|
||||
for imp in imports:
|
||||
objects.extend([obj[1:-1] for obj in imp.split(", ")])
|
||||
line_index += 1
|
||||
continue
|
||||
single_line_import_search = _re_import_struct_key_value.search(line)
|
||||
if single_line_import_search is not None:
|
||||
imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
|
||||
objects.extend(imports)
|
||||
elif line.startswith(" " * 8 + '"'):
|
||||
objects.append(line[9:-3])
|
||||
line_index += 1
|
||||
|
||||
import_dict_objects = {"none": objects}
|
||||
# Let's continue with backend-specific objects in _import_structure
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
# If the line is an if not is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
# Check if the backend declaration is inside a try block:
|
||||
if _re_try.search(lines[line_index - 1]) is None:
|
||||
backend = None
|
||||
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Scroll until we hit the else block of try-except-else
|
||||
while _re_else.search(lines[line_index]) is None:
|
||||
line_index += 1
|
||||
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
|
||||
line = lines[line_index]
|
||||
if _re_import_struct_add_one.search(line) is not None:
|
||||
objects.append(_re_import_struct_add_one.search(line).groups()[0])
|
||||
elif _re_import_struct_add_many.search(line) is not None:
|
||||
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
|
||||
imports = [obj[1:-1] for obj in imports if len(obj) > 0]
|
||||
objects.extend(imports)
|
||||
elif _re_between_brackets.search(line) is not None:
|
||||
imports = _re_between_brackets.search(line).groups()[0].split(", ")
|
||||
imports = [obj[1:-1] for obj in imports if len(obj) > 0]
|
||||
objects.extend(imports)
|
||||
elif _re_quote_object.search(line) is not None:
|
||||
objects.append(_re_quote_object.search(line).groups()[0])
|
||||
elif line.startswith(" " * 8 + '"'):
|
||||
objects.append(line[9:-3])
|
||||
elif line.startswith(" " * 12 + '"'):
|
||||
objects.append(line[13:-3])
|
||||
line_index += 1
|
||||
|
||||
import_dict_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
# At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
|
||||
objects = []
|
||||
while (
|
||||
line_index < len(lines)
|
||||
and find_backend(lines[line_index]) is None
|
||||
and not lines[line_index].startswith("else")
|
||||
):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 8):
|
||||
objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
type_hint_objects = {"none": objects}
|
||||
# Let's continue with backend-specific objects
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
# Check if the backend declaration is inside a try block:
|
||||
if _re_try.search(lines[line_index - 1]) is None:
|
||||
backend = None
|
||||
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Scroll until we hit the else block of try-except-else
|
||||
while _re_else.search(lines[line_index]) is None:
|
||||
line_index += 1
|
||||
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 12):
|
||||
objects.append(line[12:-2])
|
||||
line_index += 1
|
||||
|
||||
type_hint_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
return import_dict_objects, type_hint_objects
|
||||
|
||||
|
||||
def analyze_results(import_dict_objects, type_hint_objects):
|
||||
"""
|
||||
Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
|
||||
"""
|
||||
|
||||
def find_duplicates(seq):
|
||||
return [k for k, v in collections.Counter(seq).items() if v > 1]
|
||||
|
||||
if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
|
||||
return ["Both sides of the init do not have the same backends!"]
|
||||
|
||||
errors = []
|
||||
for key in import_dict_objects.keys():
|
||||
duplicate_imports = find_duplicates(import_dict_objects[key])
|
||||
if duplicate_imports:
|
||||
errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
|
||||
duplicate_type_hints = find_duplicates(type_hint_objects[key])
|
||||
if duplicate_type_hints:
|
||||
errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")
|
||||
|
||||
if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])):
|
||||
name = "base imports" if key == "none" else f"{key} backend"
|
||||
errors.append(f"Differences for {name}:")
|
||||
for a in type_hint_objects[key]:
|
||||
if a not in import_dict_objects[key]:
|
||||
errors.append(f" {a} in TYPE_HINT but not in _import_structure.")
|
||||
for a in import_dict_objects[key]:
|
||||
if a not in type_hint_objects[key]:
|
||||
errors.append(f" {a} in _import_structure but not in TYPE_HINT.")
|
||||
return errors
|
||||
|
||||
|
||||
def check_all_inits():
|
||||
"""
|
||||
Check all inits in the transformers repo and raise an error if at least one does not define the same objects in
|
||||
both halves.
|
||||
"""
|
||||
failures = []
|
||||
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||
if "__init__.py" in files:
|
||||
fname = os.path.join(root, "__init__.py")
|
||||
objects = parse_init(fname)
|
||||
if objects is not None:
|
||||
errors = analyze_results(*objects)
|
||||
if len(errors) > 0:
|
||||
errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}"
|
||||
failures.append("\n".join(errors))
|
||||
if len(failures) > 0:
|
||||
raise ValueError("\n\n".join(failures))
|
||||
|
||||
|
||||
def get_transformers_submodules():
|
||||
"""
|
||||
Returns the list of Transformers submodules.
|
||||
"""
|
||||
submodules = []
|
||||
for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||
for folder in directories:
|
||||
# Ignore private modules
|
||||
if folder.startswith("_"):
|
||||
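# Removing the folder from `directories` also prevents os.walk from recursing into it.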
directories.remove(folder)
|
||||
continue
|
||||
# Ignore leftovers from branches (empty folders apart from pycache)
|
||||
if len(list((Path(path) / folder).glob("*.py"))) == 0:
|
||||
continue
|
||||
short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
|
||||
submodule = short_path.replace(os.path.sep, ".")
|
||||
submodules.append(submodule)
|
||||
for fname in files:
|
||||
if fname == "__init__.py":
|
||||
continue
|
||||
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
|
||||
submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
|
||||
if len(submodule.split(".")) == 1:
|
||||
submodules.append(submodule)
|
||||
return submodules
|
||||
|
||||
|
||||
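# Submodules that are deliberately not registered in the main init.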
IGNORE_SUBMODULES = [
|
||||
"convert_pytorch_checkpoint_to_tf2",
|
||||
"modeling_flax_pytorch_utils",
|
||||
]
|
||||
|
||||
|
||||
def check_submodules():
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
|
||||
submodule_search_locations=[PATH_TO_TRANSFORMERS],
|
||||
)
|
||||
transformers = spec.loader.load_module()
|
||||
|
||||
module_not_registered = [
|
||||
module
|
||||
for module in get_transformers_submodules()
|
||||
if module not in IGNORE_SUBMODULES and module not in transformers._import_structure.keys()
|
||||
]
|
||||
if len(module_not_registered) > 0:
|
||||
list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
|
||||
raise ValueError(
|
||||
"The following submodules are not properly registed in the main init of Transformers:\n"
|
||||
f"{list_of_modules}\n"
|
||||
"Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_all_inits()
|
||||
check_submodules()
|
|
@ -0,0 +1,762 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.utils import ENV_VARS_TRUE_VALUES
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_repo.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
PATH_TO_TESTS = "tests"
|
||||
PATH_TO_DOC = "docs/source/en"
|
||||
|
||||
# Update this list with models that are supposed to be private.
|
||||
PRIVATE_MODELS = [
|
||||
"DPRSpanPredictor",
|
||||
"RealmBertModel",
|
||||
"T5Stack",
|
||||
"TFDPRSpanPredictor",
|
||||
]
|
||||
|
||||
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
||||
# Being in this list is an exception and should **not** be the rule.
|
||||
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
# models to ignore for not tested
|
||||
"OPTDecoder", # Building part of bigger (tested) model.
|
||||
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
|
||||
"SegformerDecodeHead", # Building part of bigger (tested) model.
|
||||
"PLBartEncoder", # Building part of bigger (tested) model.
|
||||
"PLBartDecoder", # Building part of bigger (tested) model.
|
||||
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
|
||||
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
|
||||
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"DetrEncoder", # Building part of bigger (tested) model.
|
||||
"DetrDecoder", # Building part of bigger (tested) model.
|
||||
"DetrDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"M2M100Encoder", # Building part of bigger (tested) model.
|
||||
"M2M100Decoder", # Building part of bigger (tested) model.
|
||||
"Speech2TextEncoder", # Building part of bigger (tested) model.
|
||||
"Speech2TextDecoder", # Building part of bigger (tested) model.
|
||||
"LEDEncoder", # Building part of bigger (tested) model.
|
||||
"LEDDecoder", # Building part of bigger (tested) model.
|
||||
"BartDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"BartEncoder", # Building part of bigger (tested) model.
|
||||
"BertLMHeadModel", # Needs to be setup as decoder.
|
||||
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
|
||||
"BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"BlenderbotEncoder", # Building part of bigger (tested) model.
|
||||
"BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MBartEncoder", # Building part of bigger (tested) model.
|
||||
"MBartDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MegatronBertLMHeadModel", # Building part of bigger (tested) model.
|
||||
"MegatronBertEncoder", # Building part of bigger (tested) model.
|
||||
"MegatronBertDecoder", # Building part of bigger (tested) model.
|
||||
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"PegasusEncoder", # Building part of bigger (tested) model.
|
||||
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"DPREncoder", # Building part of bigger (tested) model.
|
||||
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"RealmBertModel", # Building part of bigger (tested) model.
|
||||
"RealmReader", # Not regular model.
|
||||
"RealmScorer", # Not regular model.
|
||||
"RealmForOpenQA", # Not regular model.
|
||||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
|
||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||
"TFRobertaForMultipleChoice", # TODO: fix
|
||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
|
||||
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
|
||||
"OPTDecoderWrapper",
|
||||
]
|
||||
|
||||
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
|
||||
# trigger the common tests.
|
||||
TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"models/decision_transformer/test_modeling_decision_transformer.py",
|
||||
"models/camembert/test_modeling_camembert.py",
|
||||
"models/mt5/test_modeling_flax_mt5.py",
|
||||
"models/mbart/test_modeling_mbart.py",
|
||||
"models/mt5/test_modeling_mt5.py",
|
||||
"models/pegasus/test_modeling_pegasus.py",
|
||||
"models/camembert/test_modeling_tf_camembert.py",
|
||||
"models/mt5/test_modeling_tf_mt5.py",
|
||||
"models/xlm_roberta/test_modeling_tf_xlm_roberta.py",
|
||||
"models/xlm_roberta/test_modeling_flax_xlm_roberta.py",
|
||||
"models/xlm_prophetnet/test_modeling_xlm_prophetnet.py",
|
||||
"models/xlm_roberta/test_modeling_xlm_roberta.py",
|
||||
"models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
|
||||
"models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
|
||||
"models/decision_transformer/test_modeling_decision_transformer.py",
|
||||
]
|
||||
|
||||
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
|
||||
# should **not** be the rule.
|
||||
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
# models to ignore for model xxx mapping
|
||||
"DPTForDepthEstimation",
|
||||
"DecisionTransformerGPT2Model",
|
||||
"GLPNForDepthEstimation",
|
||||
"ViltForQuestionAnswering",
|
||||
"ViltForImagesAndTextClassification",
|
||||
"ViltForImageAndTextRetrieval",
|
||||
"ViltForMaskedLM",
|
||||
"XGLMEncoder",
|
||||
"XGLMDecoder",
|
||||
"XGLMDecoderWrapper",
|
||||
"PerceiverForMultimodalAutoencoding",
|
||||
"PerceiverForOpticalFlow",
|
||||
"SegformerDecodeHead",
|
||||
"FlaxBeitForMaskedImageModeling",
|
||||
"PLBartEncoder",
|
||||
"PLBartDecoder",
|
||||
"PLBartDecoderWrapper",
|
||||
"BeitForMaskedImageModeling",
|
||||
"CLIPTextModel",
|
||||
"CLIPVisionModel",
|
||||
"TFCLIPTextModel",
|
||||
"TFCLIPVisionModel",
|
||||
"FlaxCLIPTextModel",
|
||||
"FlaxCLIPVisionModel",
|
||||
"FlaxWav2Vec2ForCTC",
|
||||
"DetrForSegmentation",
|
||||
"DPRReader",
|
||||
"FlaubertForQuestionAnswering",
|
||||
"FlavaImageCodebook",
|
||||
"FlavaTextModel",
|
||||
"FlavaImageModel",
|
||||
"FlavaMultimodalModel",
|
||||
"GPT2DoubleHeadsModel",
|
||||
"LukeForMaskedLM",
|
||||
"LukeForEntityClassification",
|
||||
"LukeForEntityPairClassification",
|
||||
"LukeForEntitySpanClassification",
|
||||
"OpenAIGPTDoubleHeadsModel",
|
||||
"RagModel",
|
||||
"RagSequenceForGeneration",
|
||||
"RagTokenForGeneration",
|
||||
"RealmEmbedder",
|
||||
"RealmForOpenQA",
|
||||
"RealmScorer",
|
||||
"RealmReader",
|
||||
"TFDPRReader",
|
||||
"TFGPT2DoubleHeadsModel",
|
||||
"TFOpenAIGPTDoubleHeadsModel",
|
||||
"TFRagModel",
|
||||
"TFRagSequenceForGeneration",
|
||||
"TFRagTokenForGeneration",
|
||||
"Wav2Vec2ForCTC",
|
||||
"HubertForCTC",
|
||||
"SEWForCTC",
|
||||
"SEWDForCTC",
|
||||
"XLMForQuestionAnswering",
|
||||
"XLNetForQuestionAnswering",
|
||||
"SeparableConv1D",
|
||||
"VisualBertForRegionToPhraseAlignment",
|
||||
"VisualBertForVisualReasoning",
|
||||
"VisualBertForQuestionAnswering",
|
||||
"VisualBertForMultipleChoice",
|
||||
"TFWav2Vec2ForCTC",
|
||||
"TFHubertForCTC",
|
||||
"MaskFormerForInstanceSegmentation",
|
||||
]
|
||||
|
||||
# Update this list for models that have multiple model types for the same
|
||||
# model doc
|
||||
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
|
||||
[
|
||||
("data2vec-text", "data2vec"),
|
||||
("data2vec-audio", "data2vec"),
|
||||
("data2vec-vision", "data2vec"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
|
||||
submodule_search_locations=[PATH_TO_TRANSFORMERS],
|
||||
)
|
||||
transformers = spec.loader.load_module()
|
||||
|
||||
|
||||
def check_model_list():
|
||||
"""Check the model list inside the transformers library."""
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
|
||||
_models = []
|
||||
for model in os.listdir(models_dir):
|
||||
model_dir = os.path.join(models_dir, model)
|
||||
if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
|
||||
_models.append(model)
|
||||
|
||||
# Get the models exposed as attributes of the `transformers.models` module
|
||||
models = [model for model in dir(transformers.models) if not model.startswith("__")]
|
||||
|
||||
missing_models = sorted(list(set(_models).difference(models)))
|
||||
if missing_models:
|
||||
raise Exception(
|
||||
f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
|
||||
)
|
||||
|
||||
|
||||
# If some modeling modules should be ignored for all checks, they should be added in the nested list
|
||||
# _ignore_modules of this function.
|
||||
def get_model_modules():
|
||||
"""Get the model modules inside the transformers library."""
|
||||
_ignore_modules = [
|
||||
"modeling_auto",
|
||||
"modeling_encoder_decoder",
|
||||
"modeling_marian",
|
||||
"modeling_mmbt",
|
||||
"modeling_outputs",
|
||||
"modeling_retribert",
|
||||
"modeling_utils",
|
||||
"modeling_flax_auto",
|
||||
"modeling_flax_encoder_decoder",
|
||||
"modeling_flax_utils",
|
||||
"modeling_speech_encoder_decoder",
|
||||
"modeling_flax_speech_encoder_decoder",
|
||||
"modeling_flax_vision_encoder_decoder",
|
||||
"modeling_transfo_xl_utilities",
|
||||
"modeling_tf_auto",
|
||||
"modeling_tf_encoder_decoder",
|
||||
"modeling_tf_outputs",
|
||||
"modeling_tf_pytorch_utils",
|
||||
"modeling_tf_utils",
|
||||
"modeling_tf_transfo_xl_utilities",
|
||||
"modeling_tf_vision_encoder_decoder",
|
||||
"modeling_vision_encoder_decoder",
|
||||
]
|
||||
modules = []
|
||||
for model in dir(transformers.models):
|
||||
# There are some magic dunder attributes in the dir, we ignore them
|
||||
if not model.startswith("__"):
|
||||
model_module = getattr(transformers.models, model)
|
||||
for submodule in dir(model_module):
|
||||
if submodule.startswith("modeling") and submodule not in _ignore_modules:
|
||||
modeling_module = getattr(model_module, submodule)
|
||||
if inspect.ismodule(modeling_module):
|
||||
modules.append(modeling_module)
|
||||
return modules
|
||||
|
||||
|
||||
def get_models(module, include_pretrained=False):
|
||||
"""Get the objects in module that are models."""
|
||||
models = []
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||
for attr_name in dir(module):
|
||||
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
|
||||
continue
|
||||
attr = getattr(module, attr_name)
|
||||
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
|
||||
models.append((attr_name, attr))
|
||||
return models
|
||||
|
||||
|
||||
def is_a_private_model(model):
|
||||
"""Returns True if the model should not be in the main init."""
|
||||
if model in PRIVATE_MODELS:
|
||||
return True
|
||||
|
||||
# Wrapper, Encoder and Decoder are all privates
|
||||
if model.endswith("Wrapper"):
|
||||
return True
|
||||
if model.endswith("Encoder"):
|
||||
return True
|
||||
if model.endswith("Decoder"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_models_are_in_init():
|
||||
"""Checks all models defined in the library are in the main init."""
|
||||
models_not_in_init = []
|
||||
dir_transformers = dir(transformers)
|
||||
for module in get_model_modules():
|
||||
models_not_in_init += [
|
||||
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
|
||||
]
|
||||
|
||||
# Remove private models
|
||||
models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
|
||||
if len(models_not_in_init) > 0:
|
||||
raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
|
||||
|
||||
|
||||
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
|
||||
# nested list _ignore_files of this function.
|
||||
def get_model_test_files():
|
||||
"""Get the model test files.
|
||||
|
||||
The returned files should NOT contain the `tests` prefix (i.e. `PATH_TO_TESTS` defined in this script). They will be
|
||||
considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
|
||||
"""
|
||||
|
||||
_ignore_files = [
|
||||
"test_modeling_common",
|
||||
"test_modeling_encoder_decoder",
|
||||
"test_modeling_flax_encoder_decoder",
|
||||
"test_modeling_flax_speech_encoder_decoder",
|
||||
"test_modeling_marian",
|
||||
"test_modeling_tf_common",
|
||||
"test_modeling_tf_encoder_decoder",
|
||||
]
|
||||
test_files = []
|
||||
# Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
|
||||
model_test_root = os.path.join(PATH_TO_TESTS, "models")
|
||||
model_test_dirs = []
|
||||
for x in os.listdir(model_test_root):
|
||||
x = os.path.join(model_test_root, x)
|
||||
if os.path.isdir(x):
|
||||
model_test_dirs.append(x)
|
||||
|
||||
for target_dir in [PATH_TO_TESTS] + model_test_dirs:
|
||||
for file_or_dir in os.listdir(target_dir):
|
||||
path = os.path.join(target_dir, file_or_dir)
|
||||
if os.path.isfile(path):
|
||||
filename = os.path.split(path)[-1]
|
||||
if "test_modeling" in filename and not os.path.splitext(filename)[0] in _ignore_files:
|
||||
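# Drop the leading `tests/` component so the stored path is relative to PATH_TO_TESTS.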
file = os.path.join(*path.split(os.sep)[1:])
|
||||
test_files.append(file)
|
||||
|
||||
return test_files
|
||||
|
||||
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
|
||||
# for the all_model_classes variable.
|
||||
def find_tested_models(test_file):
|
||||
"""Parse the content of test_file to detect what's in all_model_classes"""
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
|
||||
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
|
||||
content = f.read()
|
||||
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
|
||||
# Check with one less parenthesis as well
|
||||
all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
|
||||
if len(all_models) > 0:
|
||||
model_tested = []
|
||||
for entry in all_models:
|
||||
for line in entry.split(","):
|
||||
name = line.strip()
|
||||
if len(name) > 0:
|
||||
model_tested.append(name)
|
||||
return model_tested
|
||||
|
||||
|
||||
def check_models_are_tested(module, test_file):
|
||||
"""Check models defined in module are tested in test_file."""
|
||||
# XxxPreTrainedModel are not tested
|
||||
defined_models = get_models(module)
|
||||
tested_models = find_tested_models(test_file)
|
||||
if tested_models is None:
|
||||
if test_file.replace(os.path.sep, "/") in TEST_FILES_WITH_NO_COMMON_TESTS:
|
||||
return
|
||||
return [
|
||||
f"{test_file} should define `all_model_classes` to apply common tests to the models it tests. "
|
||||
+ "If this intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file "
|
||||
+ "`utils/check_repo.py`."
|
||||
]
|
||||
failures = []
|
||||
for model_name, _ in defined_models:
|
||||
if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
|
||||
failures.append(
|
||||
f"{model_name} is defined in {module.__name__} but is not tested in "
|
||||
+ f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file."
|
||||
+ "If common tests should not applied to that model, add its name to `IGNORE_NON_TESTED`"
|
||||
+ "in the file `utils/check_repo.py`."
|
||||
)
|
||||
return failures
|
||||
|
||||
|
||||
def check_all_models_are_tested():
|
||||
"""Check all models are properly tested."""
|
||||
modules = get_model_modules()
|
||||
test_files = get_model_test_files()
|
||||
failures = []
|
||||
for module in modules:
|
||||
test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
|
||||
if len(test_file) == 0:
|
||||
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
|
||||
elif len(test_file) > 1:
|
||||
failures.append(f"{module.__name__} has several test files: {test_file}.")
|
||||
else:
|
||||
test_file = test_file[0]
|
||||
new_failures = check_models_are_tested(module, test_file)
|
||||
if new_failures is not None:
|
||||
failures += new_failures
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
def get_all_auto_configured_models():
|
||||
"""Return the list of all models in at least one auto class."""
|
||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||
if is_torch_available():
|
||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
||||
if is_tf_available():
|
||||
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
||||
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
||||
if is_flax_available():
|
||||
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
||||
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
||||
return list(result)
|
||||
|
||||
|
||||
def ignore_unautoclassed(model_name):
|
||||
"""Rules to determine if `name` should be in an auto class."""
|
||||
# Special white list
|
||||
if model_name in IGNORE_NON_AUTO_CONFIGURED:
|
||||
return True
|
||||
# Encoder and Decoder should be ignored
|
||||
if "Encoder" in model_name or "Decoder" in model_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_models_are_auto_configured(module, all_auto_models):
|
||||
"""Check models defined in module are each in an auto class."""
|
||||
defined_models = get_models(module)
|
||||
failures = []
|
||||
for model_name, _ in defined_models:
|
||||
if model_name not in all_auto_models and not ignore_unautoclassed(model_name):
|
||||
failures.append(
|
||||
f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
|
||||
"If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
|
||||
"`utils/check_repo.py`."
|
||||
)
|
||||
return failures
|
||||
|
||||
|
||||
def check_all_models_are_auto_configured():
|
||||
"""Check all models are each in an auto class."""
|
||||
missing_backends = []
|
||||
if not is_torch_available():
|
||||
missing_backends.append("PyTorch")
|
||||
if not is_tf_available():
|
||||
missing_backends.append("TensorFlow")
|
||||
if not is_flax_available():
|
||||
missing_backends.append("Flax")
|
||||
if len(missing_backends) > 0:
|
||||
missing = ", ".join(missing_backends)
|
||||
if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
|
||||
raise Exception(
|
||||
"Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
|
||||
f"Transformers repo, the following are missing: {missing}."
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
|
||||
f"Transformers repo, the following are missing: {missing}. While it's probably fine as long as you "
|
||||
"didn't make any change in one of those backends modeling files, you should probably execute the "
|
||||
"command above to be on the safe side."
|
||||
)
|
||||
modules = get_model_modules()
|
||||
all_auto_models = get_all_auto_configured_models()
|
||||
failures = []
|
||||
for module in modules:
|
||||
new_failures = check_models_are_auto_configured(module, all_auto_models)
|
||||
if new_failures is not None:
|
||||
failures += new_failures
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
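# Matches a decorator sitting alone on its line and captures everything after the `@`.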
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||
|
||||
|
||||
def check_decorator_order(filename):
|
||||
"""Check that in the test file `filename` the slow decorator is always last."""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
decorator_before = None
|
||||
errors = []
|
||||
for i, line in enumerate(lines):
|
||||
search = _re_decorator.search(line)
|
||||
if search is not None:
|
||||
decorator_name = search.groups()[0]
|
||||
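# A `parameterized` decorator found below another decorator is an error: it must be applied first.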
if decorator_before is not None and decorator_name.startswith("parameterized"):
|
||||
errors.append(i)
|
||||
decorator_before = decorator_name
|
||||
elif decorator_before is not None:
|
||||
decorator_before = None
|
||||
return errors
|
||||
|
||||
|
||||
def check_all_decorator_order():
|
||||
"""Check that in all test files, the slow decorator is always last."""
|
||||
errors = []
|
||||
for fname in os.listdir(PATH_TO_TESTS):
|
||||
if fname.endswith(".py"):
|
||||
filename = os.path.join(PATH_TO_TESTS, fname)
|
||||
new_errors = check_decorator_order(filename)
|
||||
errors += [f"- {filename}, line {i}" for i in new_errors]
|
||||
if len(errors) > 0:
|
||||
msg = "\n".join(errors)
|
||||
raise ValueError(
|
||||
"The parameterized decorator (and its variants) should always be first, but this is not the case in the"
|
||||
f" following files:\n{msg}"
|
||||
)
|
||||
|
||||
|
||||
def find_all_documented_objects():
|
||||
"""Parse the content of all doc files to detect which classes and functions it documents"""
|
||||
documented_obj = []
|
||||
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
|
||||
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
content = f.read()
|
||||
raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: transformers.(\S+)\s+", content)
|
||||
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
|
||||
for doc_file in Path(PATH_TO_DOC).glob("**/*.mdx"):
|
||||
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
content = f.read()
|
||||
raw_doc_objs = re.findall("\[\[autodoc\]\]\s+(\S+)\s+", content)
|
||||
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
|
||||
return documented_obj
|
||||
|
||||
|
||||
# One good reason for not being documented is to be deprecated. Put in this list deprecated objects.
|
||||
DEPRECATED_OBJECTS = [
|
||||
"AutoModelWithLMHead",
|
||||
"BartPretrainedModel",
|
||||
"DataCollator",
|
||||
"DataCollatorForSOP",
|
||||
"GlueDataset",
|
||||
"GlueDataTrainingArguments",
|
||||
"LineByLineTextDataset",
|
||||
"LineByLineWithRefDataset",
|
||||
"LineByLineWithSOPTextDataset",
|
||||
"PretrainedBartModel",
|
||||
"PretrainedFSMTModel",
|
||||
"SingleSentenceClassificationProcessor",
|
||||
"SquadDataTrainingArguments",
|
||||
"SquadDataset",
|
||||
"SquadExample",
|
||||
"SquadFeatures",
|
||||
"SquadV1Processor",
|
||||
"SquadV2Processor",
|
||||
"TFAutoModelWithLMHead",
|
||||
"TFBartPretrainedModel",
|
||||
"TextDataset",
|
||||
"TextDatasetForNextSentencePrediction",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2Tokenizer",
|
||||
"glue_compute_metrics",
|
||||
"glue_convert_examples_to_features",
|
||||
"glue_output_modes",
|
||||
"glue_processors",
|
||||
"glue_tasks_num_labels",
|
||||
"squad_convert_examples_to_features",
|
||||
"xnli_compute_metrics",
|
||||
"xnli_output_modes",
|
||||
"xnli_processors",
|
||||
"xnli_tasks_num_labels",
|
||||
"TFTrainer",
|
||||
"TFTrainingArguments",
|
||||
]
|
||||
|
||||
# Exceptionally, some objects should not be documented after all rules passed.
|
||||
# ONLY PUT SOMETHING IN THIS LIST AS A LAST RESORT!
|
||||
UNDOCUMENTED_OBJECTS = [
|
||||
"AddedToken", # This is a tokenizers class.
|
||||
"BasicTokenizer", # Internal, should never have been in the main init.
|
||||
"CharacterTokenizer", # Internal, should never have been in the main init.
|
||||
"DPRPretrainedReader", # Like an Encoder.
|
||||
"DummyObject", # Just picked by mistake sometimes.
|
||||
"MecabTokenizer", # Internal, should never have been in the main init.
|
||||
"ModelCard", # Internal type.
|
||||
"SqueezeBertModule", # Internal building block (should have been called SqueezeBertLayer)
|
||||
"TFDPRPretrainedReader", # Like an Encoder.
|
||||
"TransfoXLCorpus", # Internal type.
|
||||
"WordpieceTokenizer", # Internal, should never have been in the main init.
|
||||
"absl", # External module
|
||||
"add_end_docstrings", # Internal, should never have been in the main init.
|
||||
"add_start_docstrings", # Internal, should never have been in the main init.
|
||||
"cached_path", # Internal used for downloading models.
|
||||
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
|
||||
"logger", # Internal logger
|
||||
"logging", # External module
|
||||
"requires_backends", # Internal function
|
||||
]
|
||||
|
||||
# This list should be empty. Objects in it should get their own doc page.
|
||||
SHOULD_HAVE_THEIR_OWN_PAGE = [
|
||||
# Benchmarks
|
||||
"PyTorchBenchmark",
|
||||
"PyTorchBenchmarkArguments",
|
||||
"TensorFlowBenchmark",
|
||||
"TensorFlowBenchmarkArguments",
|
||||
]
|
||||
|
||||
|
||||
def ignore_undocumented(name):
|
||||
"""Rules to determine if `name` should be undocumented."""
|
||||
# NOT DOCUMENTED ON PURPOSE.
|
||||
# Constants uppercase are not documented.
|
||||
if name.isupper():
|
||||
return True
|
||||
# PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
|
||||
if (
|
||||
name.endswith("PreTrainedModel")
|
||||
or name.endswith("Decoder")
|
||||
or name.endswith("Encoder")
|
||||
or name.endswith("Layer")
|
||||
or name.endswith("Embeddings")
|
||||
or name.endswith("Attention")
|
||||
):
|
||||
return True
|
||||
# Submodules are not documented.
|
||||
if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
|
||||
os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
|
||||
):
|
||||
return True
|
||||
# All load functions are not documented.
|
||||
if name.startswith("load_tf") or name.startswith("load_pytorch"):
|
||||
return True
|
||||
# is_xxx_available functions are not documented.
|
||||
if name.startswith("is_") and name.endswith("_available"):
|
||||
return True
|
||||
# Deprecated objects are not documented.
|
||||
if name in DEPRECATED_OBJECTS or name in UNDOCUMENTED_OBJECTS:
|
||||
return True
|
||||
# MMBT model does not really work.
|
||||
if name.startswith("MMBT"):
|
||||
return True
|
||||
if name in SHOULD_HAVE_THEIR_OWN_PAGE:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_all_objects_are_documented():
|
||||
"""Check all models are properly documented."""
|
||||
documented_objs = find_all_documented_objects()
|
||||
modules = transformers._modules
|
||||
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
|
||||
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
|
||||
if len(undocumented_objs) > 0:
|
||||
raise Exception(
|
||||
"The following objects are in the public init so should be documented:\n - "
|
||||
+ "\n - ".join(undocumented_objs)
|
||||
)
|
||||
check_docstrings_are_in_md()
|
||||
check_model_type_doc_match()
|
||||
|
||||
|
||||
def check_model_type_doc_match():
|
||||
"""Check all doc pages have a corresponding model type."""
|
||||
model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
|
||||
model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]
|
||||
|
||||
model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
|
||||
model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]
|
||||
|
||||
errors = []
|
||||
for m in model_docs:
|
||||
if m not in model_types and m != "auto":
|
||||
close_matches = get_close_matches(m, model_types)
|
||||
error_message = f"{m} is not a proper model identifier."
|
||||
if len(close_matches) > 0:
|
||||
close_matches = "/".join(close_matches)
|
||||
error_message += f" Did you mean {close_matches}?"
|
||||
errors.append(error_message)
|
||||
|
||||
if len(errors) > 0:
|
||||
raise ValueError(
|
||||
"Some model doc pages do not match any existing model type:\n"
|
||||
+ "\n".join(errors)
|
||||
+ "\nYou can add any missing model type to the `MODEL_NAMES_MAPPING` constant in "
|
||||
"models/auto/configuration_auto.py."
|
||||
)
|
||||
|
||||
|
||||
# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`.
|
||||
_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`")
|
||||
# Re pattern to catch things between double backquotes.
|
||||
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
|
||||
# Re pattern to catch example introduction.
|
||||
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
|
||||
|
||||
|
||||
def is_rst_docstring(docstring):
|
||||
"""
|
||||
Returns `True` if `docstring` is written in rst.
|
||||
"""
|
||||
if _re_rst_special_words.search(docstring) is not None:
|
||||
return True
|
||||
if _re_double_backquotes.search(docstring) is not None:
|
||||
return True
|
||||
if _re_rst_example.search(docstring) is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_docstrings_are_in_md():
|
||||
"""Check all docstrings are in md"""
|
||||
files_with_rst = []
|
||||
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
|
||||
with open(file, "r") as f:
|
||||
code = f.read()
|
||||
docstrings = code.split('"""')
|
||||
|
||||
for idx, docstring in enumerate(docstrings):
|
||||
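# Splitting on triple quotes puts the docstring bodies at the odd indices; even indices are code.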
if idx % 2 == 0 or not is_rst_docstring(docstring):
|
||||
continue
|
||||
files_with_rst.append(file)
|
||||
break
|
||||
|
||||
if len(files_with_rst) > 0:
|
||||
raise ValueError(
|
||||
"The following files have docstrings written in rst:\n"
|
||||
+ "\n".join([f"- {f}" for f in files_with_rst])
|
||||
+ "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n"
|
||||
"(`pip install git+https://github.com/huggingface/doc-builder`)"
|
||||
)
|
||||
|
||||
|
||||
def check_repo_quality():
|
||||
"""Check all models are properly tested and documented."""
|
||||
print("Checking all models are included.")
|
||||
check_model_list()
|
||||
print("Checking all models are public.")
|
||||
check_models_are_in_init()
|
||||
print("Checking all models are properly tested.")
|
||||
check_all_decorator_order()
|
||||
check_all_models_are_tested()
|
||||
print("Checking all objects are properly documented.")
|
||||
check_all_objects_are_documented()
|
||||
print("Checking all models are in at least one auto class.")
|
||||
check_all_models_are_auto_configured()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_repo_quality()
|
|
@ -0,0 +1,232 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_table.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
PATH_TO_DOCS = "docs/source/en"
|
||||
REPO_PATH = "."
|
||||
|
||||
|
||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
"""
|
||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
||||
lines.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
# Find the start prompt.
|
||||
start_index = 0
|
||||
while not lines[start_index].startswith(start_prompt):
|
||||
start_index += 1
|
||||
start_index += 1
|
||||
|
||||
end_index = start_index
|
||||
while not lines[end_index].startswith(end_prompt):
|
||||
end_index += 1
|
||||
end_index -= 1
|
||||
|
||||
while len(lines[start_index]) <= 1:
|
||||
start_index += 1
|
||||
while len(lines[end_index]) <= 1:
|
||||
end_index -= 1
|
||||
end_index += 1
|
||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||
|
||||
|
||||
# Add here suffixes that are used to identify models, separated by |
|
||||
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
|
||||
# Regexes that match TF/Flax/PT model names.
|
||||
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
# Will match any TF or Flax model too, so it needs to be in an else branch after the two previous regexes.
|
||||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
|
||||
submodule_search_locations=[TRANSFORMERS_PATH],
|
||||
)
|
||||
transformers_module = spec.loader.load_module()
|
||||
|
||||
|
||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||
def camel_case_split(identifier):
|
||||
"Split a camelcased `identifier` into words."
|
||||
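# e.g. "TFBertModel" -> ["TF", "Bert", "Model"]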
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
||||
return [m.group(0) for m in matches]
|
||||
|
||||
|
||||
def _center_text(text, width):
|
||||
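# The check-mark/cross emojis render roughly two characters wide, so count them as 2 for alignment.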
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||
left_indent = (width - text_length) // 2
|
||||
right_indent = width - text_length - left_indent
|
||||
return " " * left_indent + text + " " * right_indent
|
||||
|
||||
|
||||
def get_model_table_from_auto_modules():
|
||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||
# Dictionary model names to config.
|
||||
config_mapping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
model_name_to_config = {
|
||||
name: config_mapping_names[code]
|
||||
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
||||
if code in config_mapping_names
|
||||
}
|
||||
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
||||
|
||||
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
||||
slow_tokenizers = collections.defaultdict(bool)
|
||||
fast_tokenizers = collections.defaultdict(bool)
|
||||
pt_models = collections.defaultdict(bool)
|
||||
tf_models = collections.defaultdict(bool)
|
||||
flax_models = collections.defaultdict(bool)
|
||||
|
||||
# Let's lookup through all transformers object (once).
|
||||
for attr_name in dir(transformers_module):
|
||||
lookup_dict = None
|
||||
if attr_name.endswith("Tokenizer"):
|
||||
lookup_dict = slow_tokenizers
|
||||
attr_name = attr_name[:-9]
|
||||
elif attr_name.endswith("TokenizerFast"):
|
||||
lookup_dict = fast_tokenizers
|
||||
attr_name = attr_name[:-13]
|
||||
elif _re_tf_models.match(attr_name) is not None:
|
||||
lookup_dict = tf_models
|
||||
attr_name = _re_tf_models.match(attr_name).groups()[0]
|
||||
elif _re_flax_models.match(attr_name) is not None:
|
||||
lookup_dict = flax_models
|
||||
attr_name = _re_flax_models.match(attr_name).groups()[0]
|
||||
elif _re_pt_models.match(attr_name) is not None:
|
||||
lookup_dict = pt_models
|
||||
attr_name = _re_pt_models.match(attr_name).groups()[0]
|
||||
|
||||
if lookup_dict is not None:
|
||||
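# Strip trailing camel-case words until we hit a known model prefix,
# e.g. "BertLMHead" -> "BertLM" -> "Bert".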
while len(attr_name) > 0:
|
||||
if attr_name in model_name_to_prefix.values():
|
||||
lookup_dict[attr_name] = True
|
||||
break
|
||||
# Try again after removing the last word in the name
|
||||
attr_name = "".join(camel_case_split(attr_name)[:-1])
|
||||
|
||||
# Let's build that table!
|
||||
model_names = list(model_name_to_config.keys())
|
||||
model_names.sort(key=str.lower)
|
||||
columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
|
||||
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
|
||||
widths = [len(c) + 2 for c in columns]
|
||||
widths[0] = max([len(name) for name in model_names]) + 2
|
||||
|
||||
# Build the table per se
|
||||
table = "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
|
||||
# Use ":-----:" format to center-aligned table cell texts
|
||||
table += "|" + "|".join([":" + "-" * (w - 2) + ":" for w in widths]) + "|\n"
|
||||
|
||||
check = {True: "✅", False: "❌"}
|
||||
for name in model_names:
|
||||
prefix = model_name_to_prefix[name]
|
||||
line = [
|
||||
name,
|
||||
check[slow_tokenizers[prefix]],
|
||||
check[fast_tokenizers[prefix]],
|
||||
check[pt_models[prefix]],
|
||||
check[tf_models[prefix]],
|
||||
check[flax_models[prefix]],
|
||||
]
|
||||
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
|
||||
return table
|
||||
|
||||
|
||||
def check_model_table(overwrite=False):
|
||||
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
|
||||
current_table, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
|
||||
start_prompt="<!--This table is updated automatically from the auto modules",
|
||||
end_prompt="<!-- End table-->",
|
||||
)
|
||||
new_table = get_model_table_from_auto_modules()
|
||||
|
||||
if current_table != new_table:
|
||||
if overwrite:
|
||||
with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model table in the `index.mdx` has not been updated. Run `make fix-copies` to fix this."
|
||||
)
|
||||
|
||||
|
||||
def has_onnx(model_type):
|
||||
"""
|
||||
Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not.
|
||||
"""
|
||||
config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
if model_type not in config_mapping:
|
||||
return False
|
||||
config = config_mapping[model_type]
|
||||
config_module = config.__module__
|
||||
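# Walk down from the top-level package to the module defining the config
# (e.g. transformers.models.bert.configuration_bert) and look for a matching XxxOnnxConfig there.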
module = transformers_module
|
||||
for part in config_module.split(".")[1:]:
|
||||
module = getattr(module, part)
|
||||
config_name = config.__name__
|
||||
onnx_config_name = config_name.replace("Config", "OnnxConfig")
|
||||
return hasattr(module, onnx_config_name)
|
||||
|
||||
|
||||
def get_onnx_model_list():
|
||||
"""
|
||||
Return the list of models supporting ONNX.
|
||||
"""
|
||||
config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
model_names = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
|
||||
onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)]
|
||||
onnx_model_names = [model_names[model_type] for model_type in onnx_model_types]
|
||||
onnx_model_names.sort(key=lambda x: x.upper())
|
||||
return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n"
|
||||
|
||||
|
||||
def check_onnx_model_list(overwrite=False):
|
||||
"""Check the model list in the serialization.mdx is consistent with the state of the lib and maybe `overwrite`."""
|
||||
current_list, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
|
||||
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
|
||||
end_prompt="In the next two sections, we'll show you how to:",
|
||||
)
|
||||
new_list = get_onnx_model_list()
|
||||
|
||||
if current_list != new_list:
|
||||
if overwrite:
|
||||
with open(os.path.join(PATH_TO_DOCS, "serialization.mdx"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.writelines(lines[:start_index] + [new_list] + lines[end_index:])
|
||||
else:
|
||||
raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_model_table(args.fix_and_overwrite)
|
||||
check_onnx_model_list(args.fix_and_overwrite)
|
|
@ -0,0 +1,101 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_tf_ops.py
|
||||
REPO_PATH = "."
|
||||
|
||||
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
|
||||
INTERNAL_OPS = [
|
||||
"Assert",
|
||||
"AssignVariableOp",
|
||||
"EmptyTensorList",
|
||||
"MergeV2Checkpoints",
|
||||
"ReadVariableOp",
|
||||
"ResourceGather",
|
||||
"RestoreV2",
|
||||
"SaveV2",
|
||||
"ShardedFilename",
|
||||
"StatefulPartitionedCall",
|
||||
"StaticRegexFullMatch",
|
||||
"VarHandleOp",
|
||||
]
|
||||
|
||||
|
||||
def onnx_compliancy(saved_model_path, strict, opset):
|
||||
saved_model = SavedModel()
|
||||
onnx_ops = []
|
||||
|
||||
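# `onnx.json` is expected to map each opset version (as a string) to the list of ops supported at that version.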
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
|
||||
onnx_opsets = json.load(f)["opsets"]
|
||||
|
||||
for i in range(1, opset + 1):
|
||||
onnx_ops.extend(onnx_opsets[str(i)])
|
||||
|
||||
with open(saved_model_path, "rb") as f:
|
||||
saved_model.ParseFromString(f.read())
|
||||
|
||||
model_op_names = set()
|
||||
|
||||
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
|
||||
for meta_graph in saved_model.meta_graphs:
|
||||
# Add operations in the graph definition
|
||||
model_op_names.update(node.op for node in meta_graph.graph_def.node)
|
||||
|
||||
# Go through the functions in the graph definition
|
||||
for func in meta_graph.graph_def.library.function:
|
||||
# Add operations in each function
|
||||
model_op_names.update(node.op for node in func.node_def)
|
||||
|
||||
# Sort the op names to get a deterministic, readable list
|
||||
model_op_names = sorted(model_op_names)
|
||||
incompatible_ops = []
|
||||
|
||||
for op in model_op_names:
|
||||
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||
incompatible_ops.append(op)
|
||||
|
||||
if strict and len(incompatible_ops) > 0:
|
||||
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
|
||||
elif len(incompatible_ops) > 0:
|
||||
print(f"Found the following incompatible ops for the opset {opset}:")
|
||||
print(*incompatible_ops, sep="\n")
|
||||
else:
|
||||
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
|
||||
parser.add_argument(
|
||||
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.framework == "onnx":
|
||||
onnx_compliancy(args.saved_model_path, args.strict, args.opset)
|