upload some cleaning tools

This commit is contained in:
Patrick von Platen 2022-05-31 10:17:19 +02:00
parent d849816659
commit 95f4256fc9
14 changed files with 2618 additions and 3 deletions

166
.gitignore vendored Normal file
View File

@ -0,0 +1,166 @@
# Initially taken from Github's Python gitignore file
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# vscode
.vs
.vscode
# Pycharm
.idea
# TF code
tensorflow_code
# Models
proc_data
# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep
# data
/data
serialization_dir
# emacs
*.*~
debug.env
# vim
.*.swp
#ctags
tags
# pre-commit
.pre-commit*
# .lock
*.lock
# DS_Store (MacOS)
.DS_Store

109
Makefile Normal file
View File

@ -0,0 +1,109 @@
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
check_dirs := models tests src utils
modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \
black --preview $(modified_py_files); \
isort $(modified_py_files); \
flake8 $(modified_py_files); \
else \
echo "No library .py files were modified"; \
fi
# Update src/diffusers/dependency_versions_table.py
deps_table_update:
@python setup.py deps_table_update
deps_table_check_updated:
@md5sum src/diffusers/dependency_versions_table.py > md5sum.saved
@python setup.py deps_table_update
@md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
@rm md5sum.saved
# autogenerating code
autogenerate_code: deps_table_update
# Check that the repo is in a good state
repo-consistency:
python utils/check_copies.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_inits.py
python utils/check_config_docstrings.py
python utils/tests_fetcher.py --sanity_check
# this target runs checks on all files
quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
# Format source code automatically and check is there are any problems left that need manual fixing
extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
# this target runs checks on all files and potentially modifies some of them
style:
black --preview $(check_dirs)
isort $(check_dirs)
${MAKE} autogenerate_code
${MAKE} extra_style_checks
# Super fast fix and check target that only works on relevant modified files since the branch was made
fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
# Make marked copies of snippets of codes conform to the original
fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
# Run tests for the library
test:
python -m pytest -n auto --dist=loadfile -s -v ./tests/
# Run tests for examples
test-examples:
python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/
# Run tests for SageMaker DLC release
test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker]
TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker
# Release stuff
pre-release:
python utils/release.py
pre-patch:
python utils/release.py --patch
post-release:
python utils/release.py --post_release
post-patch:
python utils/release.py --post_release --patch

1
md5sum.saved Normal file
View File

@ -0,0 +1 @@
ce075df80e7ba2391d63d026be165c15 src/diffusers/dependency_versions_table.py

161
setup.py
View File

@ -12,8 +12,164 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from setuptools import setup """
from setuptools import find_packages Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py
To create the package for pypi.
1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
documentation.
If releasing on a special branch, copy the updated README.md on the main branch for your the commit you will make
for the post-release and run `make fix-copies` on the main branch as well.
2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
3. Unpin specific versions from setup.py that use a git install.
4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
message: "Release: <VERSION>" and push.
5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
6. Add a tag in git to mark the release: "git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' "
Push the tag to git: git push --tags origin v<RELEASE>-release
7. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).
For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
(this will build a wheel for the python version you use to build it).
For the sources, run: "python setup.py sdist"
You should now have a /dist directory with both .whl and .tar.gz source versions.
8. Check that everything looks correct by uploading the package to the pypi test server:
twine upload dist/* -r pypitest
(pypi suggest using twine as other methods upload files via plaintext.)
You may have to specify the repository url, use the following command then:
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
Check that you can install it in a virtualenv by running:
pip install -i https://testpypi.python.org/pypi transformers
Check you can run the following commands:
python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from transformers import *"
9. Upload the final version to actual pypi:
twine upload dist/* -r pypi
10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
you need to go back to main before executing this.
"""
import re
from distutils.core import Command
from setuptools import find_packages, setup
# IMPORTANT:
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow",
"accelerate>=0.9.0",
"black~=22.0,>=22.3",
"codecarbon==1.2.0",
"dataclasses",
"datasets",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.1.0,<1.0",
"importlib_metadata",
"isort>=5.5.4",
"numpy>=1.17",
"pytest",
"pytest-timeout",
"pytest-xdist",
"python>=3.7.0",
"regex!=2019.12.17",
"requests",
"sagemaker>=2.31.0",
"tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch>=1.4",
"torchaudio",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
"uvicorn",
]
# this is a lookup table with items like:
#
# tokenizers: "tokenizers==0.9.4"
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
# since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
#
# python -c 'import sys; from diffusers.dependency_versions_table import deps; \
# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets
#
# Just pass the desired package names to that script as it's shown with 2 packages above.
#
# If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above
#
# You can then feed this for example to `pip`:
#
# pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \
# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets)
#
def deps_list(*pkgs):
return [deps[pkg] for pkg in pkgs]
class DepsTableUpdateCommand(Command):
"""
A custom distutils command that updates the dependency table.
usage: python setup.py deps_table_update
"""
description = "build runtime dependency table"
user_options = [
# format: (long option, short option, description).
("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"),
]
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()])
content = [
"# THIS FILE HAS BEEN AUTOGENERATED. To update:",
"# 1. modify the `_deps` dict in setup.py",
"# 2. run `make deps_table_update``",
"deps = {",
entries,
"}",
"",
]
target = "src/diffusers/dependency_versions_table.py"
print(f"updating {target}")
with open(target, "w", encoding="utf-8", newline="\n") as f:
f.write("\n".join(content))
extras = {}
extras = {} extras = {}
extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"] extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
@ -61,6 +217,7 @@ setup(
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
], ],
cmdclass={"deps_table_update": DepsTableUpdateCommand},
) )
# Release checklist # Release checklist

View File

@ -0,0 +1,47 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core
# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers
pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
if sys.version_info < (3, 7):
pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
pkgs_to_check_at_runtime.append("importlib_metadata")
for pkg in pkgs_to_check_at_runtime:
if pkg in deps:
if pkg == "tokenizers":
# must be loaded here, or else tqdm check may fail
from .utils import is_tokenizers_available
if not is_tokenizers_available():
continue # not required, check version only if installed
require_version_core(deps[pkg])
else:
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
def dep_version_check(pkg, hint=None):
require_version(deps[pkg], hint)

View File

@ -0,0 +1,31 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0",
"dataclasses": "dataclasses",
"datasets": "datasets",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"numpy": "numpy>=1.17",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.7.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"sagemaker": "sagemaker>=2.31.0",
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch": "torch>=1.4",
"torchaudio": "torchaudio",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
}

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class UNetModel:
class UNetModel:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
print("I can diffuse!") print("I can diffuse!")

View File

@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_config_docstrings.py
PATH_TO_TRANSFORMERS = "src/transformers"
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers = spec.loader.load_module()
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
# For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)`
_re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"CLIPConfig",
"DecisionTransformerConfig",
"EncoderDecoderConfig",
"RagConfig",
"SpeechEncoderDecoderConfig",
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
}
def check_config_docstrings_have_checkpoints():
configs_without_checkpoint = []
for config_class in list(CONFIG_MAPPING.values()):
checkpoint_found = False
# source code of `config_class`
config_source = inspect.getsource(config_class)
checkpoints = _re_checkpoint.findall(config_source)
for checkpoint in checkpoints:
# Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
# For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')`
ckpt_name, ckpt_link = checkpoint
# verify the checkpoint name corresponds to the checkpoint link
ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
if ckpt_link == ckpt_link_from_name:
checkpoint_found = True
break
name = config_class.__name__
if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
configs_without_checkpoint.append(name)
if len(configs_without_checkpoint) > 0:
message = "\n".join(sorted(configs_without_checkpoint))
raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}")
if __name__ == "__main__":
check_config_docstrings_have_checkpoints()

458
utils/check_copies.py Normal file
View File

@ -0,0 +1,458 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os
import re
import black
from doc_builder.style_doc import style_docstrings_in_code
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
FULL_COPIES = {
"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
"examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
}
LOCALIZED_READMES = {
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
"README.md": {
"start_prompt": "🤗 Transformers currently provides the following architectures",
"end_prompt": "1. Want to contribute a new model?",
"format_model_list": (
"**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
" {paper_authors}.{supplements}"
),
},
"README_zh-hans.md": {
"start_prompt": "🤗 Transformers 目前支持如下的架构",
"end_prompt": "1. 想要贡献新的模型?",
"format_model_list": (
"**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link}{paper_authors}"
" 发布。{supplements}"
),
},
"README_zh-hant.md": {
"start_prompt": "🤗 Transformers 目前支援以下的架構",
"end_prompt": "1. 想要貢獻新的模型?",
"format_model_list": (
"**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
" {paper_authors}.{supplements}"
),
},
"README_ko.md": {
"start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다",
"end_prompt": "1. 새로운 모델을 올리고 싶나요?",
"format_model_list": (
"**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
" {paper_authors}.{supplements}"
),
},
}
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_transformers(object_name):
"""Find and return the code source code of `object_name`."""
parts = object_name.split(".")
i = 0
# First let's find the module where our object lives.
module = parts[i]
while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
i += 1
if i < len(parts):
module = os.path.join(module, parts[i])
if i >= len(parts):
raise ValueError(
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
)
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Now let's find the class / func in the code!
indent = ""
line_index = 0
for name in parts[i + 1 :]:
while (
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
):
line_index += 1
indent += " "
line_index += 1
if line_index >= len(lines):
raise ValueError(f" {object_name} does not match any function or class in {module}.")
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
start_index = line_index
while line_index < len(lines) and _should_continue(lines[line_index], indent):
line_index += 1
# Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1:
line_index -= 1
code_lines = lines[start_index:line_index]
return "".join(code_lines)
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
def get_indent(code):
lines = code.split("\n")
idx = 0
while idx < len(lines) and len(lines[idx]) == 0:
idx += 1
if idx < len(lines):
return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
return ""
def blackify(code):
"""
Applies the black part of our `make style` command to `code`.
"""
has_indent = len(get_indent(code)) > 0
if has_indent:
code = f"class Bla:\n{code}"
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True)
result = black.format_str(code, mode=mode)
result, _ = style_docstrings_in_code(result)
return result[len("class Bla:\n") :] if has_indent else result
def is_copy_consistent(filename, overwrite=False):
"""
Check if the code commented as a copy in `filename` matches the original.
Return the differences or overwrites the content depending on `overwrite`.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
diffs = []
line_index = 0
# Not a for loop cause `lines` is going to change (if `overwrite=True`).
while line_index < len(lines):
search = _re_copy_warning.search(lines[line_index])
if search is None:
line_index += 1
continue
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
theoretical_code = find_code_in_transformers(object_name)
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
indent = theoretical_indent
line_index = start_index
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
should_continue = True
while line_index < len(lines) and should_continue:
line_index += 1
if line_index >= len(lines):
break
line = lines[line_index]
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
# Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1:
line_index -= 1
observed_code_lines = lines[start_index:line_index]
observed_code = "".join(observed_code_lines)
# Before comparing, use the `replace_pattern` on the original code.
if len(replace_pattern) > 0:
patterns = replace_pattern.replace("with", "").split(",")
patterns = [_re_replace_pattern.search(p) for p in patterns]
for pattern in patterns:
if pattern is None:
continue
obj1, obj2, option = pattern.groups()
theoretical_code = re.sub(obj1, obj2, theoretical_code)
if option.strip() == "all-casing":
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
# Blackify after replacement. To be able to do that, we need the header (class or function definition)
# from the previous line
theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
# Test for a diff and act accordingly.
if observed_code != theoretical_code:
diffs.append([object_name, start_index])
if overwrite:
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
line_index = start_index + 1
if overwrite and len(diffs) > 0:
# Warn the user a file has been modified.
print(f"Detected changes, rewriting {filename}.")
with open(filename, "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines)
return diffs
def check_copies(overwrite: bool = False):
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = []
for filename in all_files:
new_diffs = is_copy_consistent(filename, overwrite)
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
if not overwrite and len(diffs) > 0:
diff = "\n".join(diffs)
raise Exception(
"Found the following copy inconsistencies:\n"
+ diff
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
)
check_model_list_copy(overwrite=overwrite)
def check_full_copies(overwrite: bool = False):
diffs = []
for target, source in FULL_COPIES.items():
with open(source, "r", encoding="utf-8") as f:
source_code = f.read()
with open(target, "r", encoding="utf-8") as f:
target_code = f.read()
if source_code != target_code:
if overwrite:
with open(target, "w", encoding="utf-8") as f:
print(f"Replacing the content of {target} by the one of {source}.")
f.write(source_code)
else:
diffs.append(f"- {target}: copy does not match {source}.")
if not overwrite and len(diffs) > 0:
diff = "\n".join(diffs)
raise Exception(
"Found the following copy inconsistencies:\n"
+ diff
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
)
def get_model_list(filename, start_prompt, end_prompt):
"""Extracts the model list from the README."""
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
start_index = 0
while not lines[start_index].startswith(start_prompt):
start_index += 1
start_index += 1
result = []
current_line = ""
end_index = start_index
while not lines[end_index].startswith(end_prompt):
if lines[end_index].startswith("1."):
if len(current_line) > 1:
result.append(current_line)
current_line = lines[end_index]
elif len(lines[end_index]) > 1:
current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}"
end_index += 1
if len(current_line) > 1:
result.append(current_line)
return "".join(result)
def convert_to_localized_md(model_list, localized_model_list, format_str):
"""Convert `model_list` to each localized README."""
def _rep(match):
title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
return format_str.format(
title=title,
model_link=model_link,
paper_affiliations=paper_affiliations,
paper_title_link=paper_title_link,
paper_authors=paper_authors,
supplements=" " + supplements.strip() if len(supplements) != 0 else "",
)
# This regex captures metadata from an English model description, including model title, model link,
# affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
_re_capture_meta = re.compile(
r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
)
# This regex is used to synchronize link.
_re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")
if len(localized_model_list) == 0:
localized_model_index = {}
else:
try:
localized_model_index = {
re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line
for line in localized_model_list.strip().split("\n")
}
except AttributeError:
raise AttributeError("A model name in localized READMEs cannot be recognized.")
model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]
# We exclude keys in localized README not in the main one.
readmes_match = not any([k not in model_keys for k in localized_model_index])
localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}
for model in model_list.strip().split("\n"):
title, model_link = _re_capture_title_link.search(model).groups()
if title not in localized_model_index:
readmes_match = False
# Add an anchor white space behind a model description string for regex.
# If metadata cannot be captured, the English version will be directly copied.
localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
else:
# Synchronize link
localized_model_index[title] = _re_capture_title_link.sub(
f"**[{title}]({model_link})**", localized_model_index[title], count=1
)
sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())
return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
def convert_readme_to_index(model_list):
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
return model_list.replace("https://huggingface.co/docs/transformers/", "")
def _find_text_in_file(filename, start_prompt, end_prompt):
"""
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
lines.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start prompt.
start_index = 0
while not lines[start_index].startswith(start_prompt):
start_index += 1
start_index += 1
end_index = start_index
while not lines[end_index].startswith(end_prompt):
end_index += 1
end_index -= 1
while len(lines[start_index]) <= 1:
start_index += 1
while len(lines[end_index]) <= 1:
end_index -= 1
end_index += 1
return "".join(lines[start_index:end_index]), start_index, end_index, lines
def check_model_list_copy(overwrite=False, max_per_line=119):
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
# Fix potential doc links in the README
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
readme = f.read()
new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
new_readme = new_readme.replace(
"https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
)
if new_readme != readme:
if overwrite:
with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f:
f.write(new_readme)
else:
raise ValueError(
"The main README contains wrong links to the documentation of Transformers. Run `make fix-copies` to "
"automatically fix them."
)
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
index_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
start_prompt="<!--This list is updated automatically from the README",
end_prompt="### Supported frameworks",
)
md_list = get_model_list(
filename="README.md",
start_prompt=LOCALIZED_READMES["README.md"]["start_prompt"],
end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"],
)
converted_md_lists = []
for filename, value in LOCALIZED_READMES.items():
_start_prompt = value["start_prompt"]
_end_prompt = value["end_prompt"]
_format_model_list = value["format_model_list"]
localized_md_list = get_model_list(filename, _start_prompt, _end_prompt)
readmes_match, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)
converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))
converted_md_list = convert_readme_to_index(md_list)
if converted_md_list != index_list:
if overwrite:
with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [converted_md_list] + lines[end_index:])
else:
raise ValueError(
"The model list in the README changed and the list in `index.mdx` has not been updated. Run "
"`make fix-copies` to fix this."
)
for converted_md_list in converted_md_lists:
filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list
if filename == "README.md":
continue
if overwrite:
_, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(REPO_PATH, filename), start_prompt=_start_prompt, end_prompt=_end_prompt
)
with open(os.path.join(REPO_PATH, filename), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [converted_md] + lines[end_index:])
elif not readmes_match:
raise ValueError(
f"The model list in the README changed and the list in `{filename}` has not been updated. Run "
"`make fix-copies` to fix this."
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
check_copies(args.fix_and_overwrite)
check_full_copies(args.fix_and_overwrite)

168
utils/check_dummies.py Normal file
View File

@ -0,0 +1,168 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
DUMMY_CONSTANT = """
{0} = None
"""
DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
_backends = {1}
def __init__(self, *args, **kwargs):
requires_backends(self, {1})
"""
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
requires_backends({0}, {1})
"""
def find_backend(line):
"""Find one (or multiple) backend in a code line of the init."""
if _re_test_backend.search(line) is None:
return None
backends = [b[0] for b in _re_backend.findall(line)]
backends.sort()
return "_and_".join(backends)
def read_init():
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Get to the point we do the actual imports for type checking
line_index = 0
while not lines[line_index].startswith("if TYPE_CHECKING"):
line_index += 1
backend_specific_objects = {}
# Go through the end of the file
while line_index < len(lines):
# If the line is an if is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
if backend is not None:
while not lines[line_index].startswith(" else:"):
line_index += 1
line_index += 1
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
line = lines[line_index]
single_line_import_search = _re_single_line_import.search(line)
if single_line_import_search is not None:
objects.extend(single_line_import_search.groups()[0].split(", "))
elif line.startswith(" " * 12):
objects.append(line[12:-2])
line_index += 1
backend_specific_objects[backend] = objects
else:
line_index += 1
return backend_specific_objects
def create_dummy_object(name, backend_name):
"""Create the code for the dummy object corresponding to `name`."""
if name.isupper():
return DUMMY_CONSTANT.format(name)
elif name.islower():
return DUMMY_FUNCTION.format(name, backend_name)
else:
return DUMMY_CLASS.format(name, backend_name)
def create_dummy_files():
"""Create the content of the dummy files."""
backend_specific_objects = read_init()
# For special correspondence backend to module name as used in the function requires_modulename
dummy_files = {}
for backend, objects in backend_specific_objects.items():
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += "# flake8: noqa\n"
dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
dummy_files[backend] = dummy_file
return dummy_files
def check_dummies(overwrite=False):
"""Check if the dummy files are up to date and maybe `overwrite` with the right content."""
dummy_files = create_dummy_files()
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
short_names = {"torch": "pt"}
# Locate actual dummy modules and read their content.
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
dummy_file_paths = {
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
for backend in dummy_files.keys()
}
actual_dummies = {}
for backend, file_path in dummy_file_paths.items():
if os.path.isfile(file_path):
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
actual_dummies[backend] = f.read()
else:
actual_dummies[backend] = ""
for backend in dummy_files.keys():
if dummy_files[backend] != actual_dummies[backend]:
if overwrite:
print(
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
"__init__ has new objects."
)
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
f.write(dummy_files[backend])
else:
raise ValueError(
"The main __init__ has objects that are not present in "
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
"to fix this."
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
check_dummies(args.fix_and_overwrite)

299
utils/check_inits.py Normal file
View File

@ -0,0 +1,299 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import importlib.util
import os
import re
from pathlib import Path
PATH_TO_TRANSFORMERS = "src/transformers"
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Catches a one-line _import_struct = {xxx}
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if not is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
_re_quote_object = re.compile('^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
_re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Catches a line with try:
_re_try = re.compile(r"^\s*try:")
# Catches a line with else:
_re_else = re.compile(r"^\s*else:")
def find_backend(line):
"""Find one (or multiple) backend in a code line of the init."""
if _re_test_backend.search(line) is None:
return None
backends = [b[0] for b in _re_backend.findall(line)]
backends.sort()
return "_and_".join(backends)
def parse_init(init_file):
"""
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
defined
"""
with open(init_file, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
line_index = 0
while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
line_index += 1
# If this is a traditional init, just return.
if line_index >= len(lines):
return None
# First grab the objects without a specific backend in _import_structure
objects = []
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
line = lines[line_index]
# If we have everything on a single line, let's deal with it.
if _re_one_line_import_struct.search(line):
content = _re_one_line_import_struct.search(line).groups()[0]
imports = re.findall("\[([^\]]+)\]", content)
for imp in imports:
objects.extend([obj[1:-1] for obj in imp.split(", ")])
line_index += 1
continue
single_line_import_search = _re_import_struct_key_value.search(line)
if single_line_import_search is not None:
imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
objects.extend(imports)
elif line.startswith(" " * 8 + '"'):
objects.append(line[9:-3])
line_index += 1
import_dict_objects = {"none": objects}
# Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"):
# If the line is an if not is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
# Check if the backend declaration is inside a try block:
if _re_try.search(lines[line_index - 1]) is None:
backend = None
if backend is not None:
line_index += 1
# Scroll until we hit the else block of try-except-else
while _re_else.search(lines[line_index]) is None:
line_index += 1
line_index += 1
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
line = lines[line_index]
if _re_import_struct_add_one.search(line) is not None:
objects.append(_re_import_struct_add_one.search(line).groups()[0])
elif _re_import_struct_add_many.search(line) is not None:
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
imports = [obj[1:-1] for obj in imports if len(obj) > 0]
objects.extend(imports)
elif _re_between_brackets.search(line) is not None:
imports = _re_between_brackets.search(line).groups()[0].split(", ")
imports = [obj[1:-1] for obj in imports if len(obj) > 0]
objects.extend(imports)
elif _re_quote_object.search(line) is not None:
objects.append(_re_quote_object.search(line).groups()[0])
elif line.startswith(" " * 8 + '"'):
objects.append(line[9:-3])
elif line.startswith(" " * 12 + '"'):
objects.append(line[13:-3])
line_index += 1
import_dict_objects[backend] = objects
else:
line_index += 1
# At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
objects = []
while (
line_index < len(lines)
and find_backend(lines[line_index]) is None
and not lines[line_index].startswith("else")
):
line = lines[line_index]
single_line_import_search = _re_import.search(line)
if single_line_import_search is not None:
objects.extend(single_line_import_search.groups()[0].split(", "))
elif line.startswith(" " * 8):
objects.append(line[8:-2])
line_index += 1
type_hint_objects = {"none": objects}
# Let's continue with backend-specific objects
while line_index < len(lines):
# If the line is an if is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
# Check if the backend declaration is inside a try block:
if _re_try.search(lines[line_index - 1]) is None:
backend = None
if backend is not None:
line_index += 1
# Scroll until we hit the else block of try-except-else
while _re_else.search(lines[line_index]) is None:
line_index += 1
line_index += 1
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
line = lines[line_index]
single_line_import_search = _re_import.search(line)
if single_line_import_search is not None:
objects.extend(single_line_import_search.groups()[0].split(", "))
elif line.startswith(" " * 12):
objects.append(line[12:-2])
line_index += 1
type_hint_objects[backend] = objects
else:
line_index += 1
return import_dict_objects, type_hint_objects
def analyze_results(import_dict_objects, type_hint_objects):
"""
Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
"""
def find_duplicates(seq):
return [k for k, v in collections.Counter(seq).items() if v > 1]
if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
return ["Both sides of the init do not have the same backends!"]
errors = []
for key in import_dict_objects.keys():
duplicate_imports = find_duplicates(import_dict_objects[key])
if duplicate_imports:
errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
duplicate_type_hints = find_duplicates(type_hint_objects[key])
if duplicate_type_hints:
errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")
if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])):
name = "base imports" if key == "none" else f"{key} backend"
errors.append(f"Differences for {name}:")
for a in type_hint_objects[key]:
if a not in import_dict_objects[key]:
errors.append(f" {a} in TYPE_HINT but not in _import_structure.")
for a in import_dict_objects[key]:
if a not in type_hint_objects[key]:
errors.append(f" {a} in _import_structure but not in TYPE_HINT.")
return errors
def check_all_inits():
"""
Check all inits in the transformers repo and raise an error if at least one does not define the same objects in
both halves.
"""
failures = []
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
if "__init__.py" in files:
fname = os.path.join(root, "__init__.py")
objects = parse_init(fname)
if objects is not None:
errors = analyze_results(*objects)
if len(errors) > 0:
errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}"
failures.append("\n".join(errors))
if len(failures) > 0:
raise ValueError("\n\n".join(failures))
def get_transformers_submodules():
"""
Returns the list of Transformers submodules.
"""
submodules = []
for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
for folder in directories:
# Ignore private modules
if folder.startswith("_"):
directories.remove(folder)
continue
# Ignore leftovers from branches (empty folders apart from pycache)
if len(list((Path(path) / folder).glob("*.py"))) == 0:
continue
short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
submodule = short_path.replace(os.path.sep, ".")
submodules.append(submodule)
for fname in files:
if fname == "__init__.py":
continue
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
if len(submodule.split(".")) == 1:
submodules.append(submodule)
return submodules
IGNORE_SUBMODULES = [
"convert_pytorch_checkpoint_to_tf2",
"modeling_flax_pytorch_utils",
]
def check_submodules():
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers = spec.loader.load_module()
module_not_registered = [
module
for module in get_transformers_submodules()
if module not in IGNORE_SUBMODULES and module not in transformers._import_structure.keys()
]
if len(module_not_registered) > 0:
list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
raise ValueError(
"The following submodules are not properly registed in the main init of Transformers:\n"
f"{list_of_modules}\n"
"Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
)
if __name__ == "__main__":
check_all_inits()
check_submodules()

762
utils/check_repo.py Normal file
View File

@ -0,0 +1,762 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
import re
import warnings
from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.utils import ENV_VARS_TRUE_VALUES
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_repo.py
PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source/en"
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
"DPRSpanPredictor",
"RealmBertModel",
"T5Stack",
"TFDPRSpanPredictor",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"OPTDecoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model.
"PLBartDecoder", # Building part of bigger (tested) model.
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
"DetrEncoder", # Building part of bigger (tested) model.
"DetrDecoder", # Building part of bigger (tested) model.
"DetrDecoderWrapper", # Building part of bigger (tested) model.
"M2M100Encoder", # Building part of bigger (tested) model.
"M2M100Decoder", # Building part of bigger (tested) model.
"Speech2TextEncoder", # Building part of bigger (tested) model.
"Speech2TextDecoder", # Building part of bigger (tested) model.
"LEDEncoder", # Building part of bigger (tested) model.
"LEDDecoder", # Building part of bigger (tested) model.
"BartDecoderWrapper", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder.
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
"BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
"BlenderbotEncoder", # Building part of bigger (tested) model.
"BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
"MBartEncoder", # Building part of bigger (tested) model.
"MBartDecoderWrapper", # Building part of bigger (tested) model.
"MegatronBertLMHeadModel", # Building part of bigger (tested) model.
"MegatronBertEncoder", # Building part of bigger (tested) model.
"MegatronBertDecoder", # Building part of bigger (tested) model.
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"RealmBertModel", # Building part of bigger (tested) model.
"RealmReader", # Not regular model.
"RealmScorer", # Not regular model.
"RealmForOpenQA", # Not regular model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
"OPTDecoderWrapper",
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
# trigger the common tests.
TEST_FILES_WITH_NO_COMMON_TESTS = [
"models/decision_transformer/test_modeling_decision_transformer.py",
"models/camembert/test_modeling_camembert.py",
"models/mt5/test_modeling_flax_mt5.py",
"models/mbart/test_modeling_mbart.py",
"models/mt5/test_modeling_mt5.py",
"models/pegasus/test_modeling_pegasus.py",
"models/camembert/test_modeling_tf_camembert.py",
"models/mt5/test_modeling_tf_mt5.py",
"models/xlm_roberta/test_modeling_tf_xlm_roberta.py",
"models/xlm_roberta/test_modeling_flax_xlm_roberta.py",
"models/xlm_prophetnet/test_modeling_xlm_prophetnet.py",
"models/xlm_roberta/test_modeling_xlm_roberta.py",
"models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
"models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
"models/decision_transformer/test_modeling_decision_transformer.py",
]
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"DPTForDepthEstimation",
"DecisionTransformerGPT2Model",
"GLPNForDepthEstimation",
"ViltForQuestionAnswering",
"ViltForImagesAndTextClassification",
"ViltForImageAndTextRetrieval",
"ViltForMaskedLM",
"XGLMEncoder",
"XGLMDecoder",
"XGLMDecoderWrapper",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",
"FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",
"PLBartDecoderWrapper",
"BeitForMaskedImageModeling",
"CLIPTextModel",
"CLIPVisionModel",
"TFCLIPTextModel",
"TFCLIPVisionModel",
"FlaxCLIPTextModel",
"FlaxCLIPVisionModel",
"FlaxWav2Vec2ForCTC",
"DetrForSegmentation",
"DPRReader",
"FlaubertForQuestionAnswering",
"FlavaImageCodebook",
"FlavaTextModel",
"FlavaImageModel",
"FlavaMultimodalModel",
"GPT2DoubleHeadsModel",
"LukeForMaskedLM",
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"OpenAIGPTDoubleHeadsModel",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"RealmEmbedder",
"RealmForOpenQA",
"RealmScorer",
"RealmReader",
"TFDPRReader",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
"TFRagModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
"Wav2Vec2ForCTC",
"HubertForCTC",
"SEWForCTC",
"SEWDForCTC",
"XLMForQuestionAnswering",
"XLNetForQuestionAnswering",
"SeparableConv1D",
"VisualBertForRegionToPhraseAlignment",
"VisualBertForVisualReasoning",
"VisualBertForQuestionAnswering",
"VisualBertForMultipleChoice",
"TFWav2Vec2ForCTC",
"TFHubertForCTC",
"MaskFormerForInstanceSegmentation",
]
# Update this list for models that have multiple model types for the same
# model doc
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
[
("data2vec-text", "data2vec"),
("data2vec-audio", "data2vec"),
("data2vec-vision", "data2vec"),
]
)
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers = spec.loader.load_module()
def check_model_list():
"""Check the model list inside the transformers library."""
# Get the models from the directory structure of `src/transformers/models/`
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
_models = []
for model in os.listdir(models_dir):
model_dir = os.path.join(models_dir, model)
if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
_models.append(model)
# Get the models from the directory structure of `src/transformers/models/`
models = [model for model in dir(transformers.models) if not model.startswith("__")]
missing_models = sorted(list(set(_models).difference(models)))
if missing_models:
raise Exception(
f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
)
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
"""Get the model modules inside the transformers library."""
_ignore_modules = [
"modeling_auto",
"modeling_encoder_decoder",
"modeling_marian",
"modeling_mmbt",
"modeling_outputs",
"modeling_retribert",
"modeling_utils",
"modeling_flax_auto",
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
"modeling_tf_encoder_decoder",
"modeling_tf_outputs",
"modeling_tf_pytorch_utils",
"modeling_tf_utils",
"modeling_tf_transfo_xl_utilities",
"modeling_tf_vision_encoder_decoder",
"modeling_vision_encoder_decoder",
]
modules = []
for model in dir(transformers.models):
# There are some magic dunder attributes in the dir, we ignore them
if not model.startswith("__"):
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
return modules
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
continue
attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
models.append((attr_name, attr))
return models
def is_a_private_model(model):
"""Returns True if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
# Wrapper, Encoder and Decoder are all privates
if model.endswith("Wrapper"):
return True
if model.endswith("Encoder"):
return True
if model.endswith("Decoder"):
return True
return False
def check_models_are_in_init():
"""Checks all models defined in the library are in the main init."""
models_not_in_init = []
dir_transformers = dir(transformers)
for module in get_model_modules():
models_not_in_init += [
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
]
# Remove private models
models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
if len(models_not_in_init) > 0:
raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
"""Get the model test files.
The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
"""
_ignore_files = [
"test_modeling_common",
"test_modeling_encoder_decoder",
"test_modeling_flax_encoder_decoder",
"test_modeling_flax_speech_encoder_decoder",
"test_modeling_marian",
"test_modeling_tf_common",
"test_modeling_tf_encoder_decoder",
]
test_files = []
# Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
model_test_root = os.path.join(PATH_TO_TESTS, "models")
model_test_dirs = []
for x in os.listdir(model_test_root):
x = os.path.join(model_test_root, x)
if os.path.isdir(x):
model_test_dirs.append(x)
for target_dir in [PATH_TO_TESTS] + model_test_dirs:
for file_or_dir in os.listdir(target_dir):
path = os.path.join(target_dir, file_or_dir)
if os.path.isfile(path):
filename = os.path.split(path)[-1]
if "test_modeling" in filename and not os.path.splitext(filename)[0] in _ignore_files:
file = os.path.join(*path.split(os.sep)[1:])
test_files.append(file)
return test_files
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
"""Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
content = f.read()
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
# Check with one less parenthesis as well
all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
if len(all_models) > 0:
model_tested = []
for entry in all_models:
for line in entry.split(","):
name = line.strip()
if len(name) > 0:
model_tested.append(name)
return model_tested
def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file."""
# XxxPreTrainedModel are not tested
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
if tested_models is None:
if test_file.replace(os.path.sep, "/") in TEST_FILES_WITH_NO_COMMON_TESTS:
return
return [
f"{test_file} should define `all_model_classes` to apply common tests to the models it tests. "
+ "If this intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file "
+ "`utils/check_repo.py`."
]
failures = []
for model_name, _ in defined_models:
if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
failures.append(
f"{model_name} is defined in {module.__name__} but is not tested in "
+ f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file."
+ "If common tests should not applied to that model, add its name to `IGNORE_NON_TESTED`"
+ "in the file `utils/check_repo.py`."
)
return failures
def check_all_models_are_tested():
"""Check all models are properly tested."""
modules = get_model_modules()
test_files = get_model_test_files()
failures = []
for module in modules:
test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
if len(test_file) == 0:
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
elif len(test_file) > 1:
failures.append(f"{module.__name__} has several test files: {test_file}.")
else:
test_file = test_file[0]
new_failures = check_models_are_tested(module, test_file)
if new_failures is not None:
failures += new_failures
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models():
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
if is_torch_available():
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
if is_tf_available():
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
if is_flax_available():
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
return [cls for cls in result]
def ignore_unautoclassed(model_name):
"""Rules to determine if `name` should be in an auto class."""
# Special white list
if model_name in IGNORE_NON_AUTO_CONFIGURED:
return True
# Encoder and Decoder should be ignored
if "Encoder" in model_name or "Decoder" in model_name:
return True
return False
def check_models_are_auto_configured(module, all_auto_models):
"""Check models defined in module are each in an auto class."""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
if model_name not in all_auto_models and not ignore_unautoclassed(model_name):
failures.append(
f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
"If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
"`utils/check_repo.py`."
)
return failures
def check_all_models_are_auto_configured():
"""Check all models are each in an auto class."""
missing_backends = []
if not is_torch_available():
missing_backends.append("PyTorch")
if not is_tf_available():
missing_backends.append("TensorFlow")
if not is_flax_available():
missing_backends.append("Flax")
if len(missing_backends) > 0:
missing = ", ".join(missing_backends)
if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
raise Exception(
"Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
f"Transformers repo, the following are missing: {missing}."
)
else:
warnings.warn(
"Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
f"Transformers repo, the following are missing: {missing}. While it's probably fine as long as you "
"didn't make any change in one of those backends modeling files, you should probably execute the "
"command above to be on the safe side."
)
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
failures = []
for module in modules:
new_failures = check_models_are_auto_configured(module, all_auto_models)
if new_failures is not None:
failures += new_failures
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
def check_decorator_order(filename):
"""Check that in the test file `filename` the slow decorator is always last."""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
decorator_before = None
errors = []
for i, line in enumerate(lines):
search = _re_decorator.search(line)
if search is not None:
decorator_name = search.groups()[0]
if decorator_before is not None and decorator_name.startswith("parameterized"):
errors.append(i)
decorator_before = decorator_name
elif decorator_before is not None:
decorator_before = None
return errors
def check_all_decorator_order():
"""Check that in all test files, the slow decorator is always last."""
errors = []
for fname in os.listdir(PATH_TO_TESTS):
if fname.endswith(".py"):
filename = os.path.join(PATH_TO_TESTS, fname)
new_errors = check_decorator_order(filename)
errors += [f"- {filename}, line {i}" for i in new_errors]
if len(errors) > 0:
msg = "\n".join(errors)
raise ValueError(
"The parameterized decorator (and its variants) should always be first, but this is not the case in the"
f" following files:\n{msg}"
)
def find_all_documented_objects():
"""Parse the content of all doc files to detect which classes and functions it documents"""
documented_obj = []
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
content = f.read()
raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: transformers.(\S+)\s+", content)
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
for doc_file in Path(PATH_TO_DOC).glob("**/*.mdx"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
content = f.read()
raw_doc_objs = re.findall("\[\[autodoc\]\]\s+(\S+)\s+", content)
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
return documented_obj
# One good reason for not being documented is to be deprecated. Put in this list deprecated objects.
DEPRECATED_OBJECTS = [
"AutoModelWithLMHead",
"BartPretrainedModel",
"DataCollator",
"DataCollatorForSOP",
"GlueDataset",
"GlueDataTrainingArguments",
"LineByLineTextDataset",
"LineByLineWithRefDataset",
"LineByLineWithSOPTextDataset",
"PretrainedBartModel",
"PretrainedFSMTModel",
"SingleSentenceClassificationProcessor",
"SquadDataTrainingArguments",
"SquadDataset",
"SquadExample",
"SquadFeatures",
"SquadV1Processor",
"SquadV2Processor",
"TFAutoModelWithLMHead",
"TFBartPretrainedModel",
"TextDataset",
"TextDatasetForNextSentencePrediction",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2Tokenizer",
"glue_compute_metrics",
"glue_convert_examples_to_features",
"glue_output_modes",
"glue_processors",
"glue_tasks_num_labels",
"squad_convert_examples_to_features",
"xnli_compute_metrics",
"xnli_output_modes",
"xnli_processors",
"xnli_tasks_num_labels",
"TFTrainer",
"TFTrainingArguments",
]
# Exceptionally, some objects should not be documented after all rules passed.
# ONLY PUT SOMETHING IN THIS LIST AS A LAST RESORT!
UNDOCUMENTED_OBJECTS = [
"AddedToken", # This is a tokenizers class.
"BasicTokenizer", # Internal, should never have been in the main init.
"CharacterTokenizer", # Internal, should never have been in the main init.
"DPRPretrainedReader", # Like an Encoder.
"DummyObject", # Just picked by mistake sometimes.
"MecabTokenizer", # Internal, should never have been in the main init.
"ModelCard", # Internal type.
"SqueezeBertModule", # Internal building block (should have been called SqueezeBertLayer)
"TFDPRPretrainedReader", # Like an Encoder.
"TransfoXLCorpus", # Internal type.
"WordpieceTokenizer", # Internal, should never have been in the main init.
"absl", # External module
"add_end_docstrings", # Internal, should never have been in the main init.
"add_start_docstrings", # Internal, should never have been in the main init.
"cached_path", # Internal used for downloading models.
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
"logger", # Internal logger
"logging", # External module
"requires_backends", # Internal function
]
# This list should be empty. Objects in it should get their own doc page.
SHOULD_HAVE_THEIR_OWN_PAGE = [
# Benchmarks
"PyTorchBenchmark",
"PyTorchBenchmarkArguments",
"TensorFlowBenchmark",
"TensorFlowBenchmarkArguments",
]
def ignore_undocumented(name):
"""Rules to determine if `name` should be undocumented."""
# NOT DOCUMENTED ON PURPOSE.
# Constants uppercase are not documented.
if name.isupper():
return True
# PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
if (
name.endswith("PreTrainedModel")
or name.endswith("Decoder")
or name.endswith("Encoder")
or name.endswith("Layer")
or name.endswith("Embeddings")
or name.endswith("Attention")
):
return True
# Submodules are not documented.
if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
):
return True
# All load functions are not documented.
if name.startswith("load_tf") or name.startswith("load_pytorch"):
return True
# is_xxx_available functions are not documented.
if name.startswith("is_") and name.endswith("_available"):
return True
# Deprecated objects are not documented.
if name in DEPRECATED_OBJECTS or name in UNDOCUMENTED_OBJECTS:
return True
# MMBT model does not really work.
if name.startswith("MMBT"):
return True
if name in SHOULD_HAVE_THEIR_OWN_PAGE:
return True
return False
def check_all_objects_are_documented():
"""Check all models are properly documented."""
documented_objs = find_all_documented_objects()
modules = transformers._modules
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
if len(undocumented_objs) > 0:
raise Exception(
"The following objects are in the public init so should be documented:\n - "
+ "\n - ".join(undocumented_objs)
)
check_docstrings_are_in_md()
check_model_type_doc_match()
def check_model_type_doc_match():
"""Check all doc pages have a corresponding model type."""
model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]
model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]
errors = []
for m in model_docs:
if m not in model_types and m != "auto":
close_matches = get_close_matches(m, model_types)
error_message = f"{m} is not a proper model identifier."
if len(close_matches) > 0:
close_matches = "/".join(close_matches)
error_message += f" Did you mean {close_matches}?"
errors.append(error_message)
if len(errors) > 0:
raise ValueError(
"Some model doc pages do not match any existing model type:\n"
+ "\n".join(errors)
+ "\nYou can add any missing model type to the `MODEL_NAMES_MAPPING` constant in "
"models/auto/configuration_auto.py."
)
# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`.
_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch example introduction.
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
def is_rst_docstring(docstring):
"""
Returns `True` if `docstring` is written in rst.
"""
if _re_rst_special_words.search(docstring) is not None:
return True
if _re_double_backquotes.search(docstring) is not None:
return True
if _re_rst_example.search(docstring) is not None:
return True
return False
def check_docstrings_are_in_md():
"""Check all docstrings are in md"""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
with open(file, "r") as f:
code = f.read()
docstrings = code.split('"""')
for idx, docstring in enumerate(docstrings):
if idx % 2 == 0 or not is_rst_docstring(docstring):
continue
files_with_rst.append(file)
break
if len(files_with_rst) > 0:
raise ValueError(
"The following files have docstrings written in rst:\n"
+ "\n".join([f"- {f}" for f in files_with_rst])
+ "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n"
"(`pip install git+https://github.com/huggingface/doc-builder`)"
)
def check_repo_quality():
"""Check all models are properly tested and documented."""
print("Checking all models are included.")
check_model_list()
print("Checking all models are public.")
check_models_are_in_init()
print("Checking all models are properly tested.")
check_all_decorator_order()
check_all_models_are_tested()
print("Checking all objects are properly documented.")
check_all_objects_are_documented()
print("Checking all models are in at least one auto class.")
check_all_models_are_auto_configured()
if __name__ == "__main__":
check_repo_quality()

232
utils/check_table.py Normal file
View File

@ -0,0 +1,232 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import collections
import importlib.util
import os
import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_table.py
TRANSFORMERS_PATH = "src/transformers"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
def _find_text_in_file(filename, start_prompt, end_prompt):
"""
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
lines.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start prompt.
start_index = 0
while not lines[start_index].startswith(start_prompt):
start_index += 1
start_index += 1
end_index = start_index
while not lines[end_index].startswith(end_prompt):
end_index += 1
end_index -= 1
while len(lines[start_index]) <= 1:
start_index += 1
while len(lines[end_index]) <= 1:
end_index -= 1
end_index += 1
return "".join(lines[start_index:end_index]), start_index, end_index, lines
# Add here suffixes that are used to identify models, seperated by |
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
# Regexes that match TF/Flax/PT model names.
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
submodule_search_locations=[TRANSFORMERS_PATH],
)
transformers_module = spec.loader.load_module()
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def camel_case_split(identifier):
"Split a camelcased `identifier` into words."
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
return [m.group(0) for m in matches]
def _center_text(text, width):
text_length = 2 if text == "" or text == "" else len(text)
left_indent = (width - text_length) // 2
right_indent = width - text_length - left_indent
return " " * left_indent + text + " " * right_indent
def get_model_table_from_auto_modules():
"""Generates an up-to-date model table from the content of the auto modules."""
# Dictionary model names to config.
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
model_name_to_config = {
name: config_maping_names[code]
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
if code in config_maping_names
}
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
slow_tokenizers = collections.defaultdict(bool)
fast_tokenizers = collections.defaultdict(bool)
pt_models = collections.defaultdict(bool)
tf_models = collections.defaultdict(bool)
flax_models = collections.defaultdict(bool)
# Let's lookup through all transformers object (once).
for attr_name in dir(transformers_module):
lookup_dict = None
if attr_name.endswith("Tokenizer"):
lookup_dict = slow_tokenizers
attr_name = attr_name[:-9]
elif attr_name.endswith("TokenizerFast"):
lookup_dict = fast_tokenizers
attr_name = attr_name[:-13]
elif _re_tf_models.match(attr_name) is not None:
lookup_dict = tf_models
attr_name = _re_tf_models.match(attr_name).groups()[0]
elif _re_flax_models.match(attr_name) is not None:
lookup_dict = flax_models
attr_name = _re_flax_models.match(attr_name).groups()[0]
elif _re_pt_models.match(attr_name) is not None:
lookup_dict = pt_models
attr_name = _re_pt_models.match(attr_name).groups()[0]
if lookup_dict is not None:
while len(attr_name) > 0:
if attr_name in model_name_to_prefix.values():
lookup_dict[attr_name] = True
break
# Try again after removing the last word in the name
attr_name = "".join(camel_case_split(attr_name)[:-1])
# Let's build that table!
model_names = list(model_name_to_config.keys())
model_names.sort(key=str.lower)
columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
widths = [len(c) + 2 for c in columns]
widths[0] = max([len(name) for name in model_names]) + 2
# Build the table per se
table = "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
# Use ":-----:" format to center-aligned table cell texts
table += "|" + "|".join([":" + "-" * (w - 2) + ":" for w in widths]) + "|\n"
check = {True: "", False: ""}
for name in model_names:
prefix = model_name_to_prefix[name]
line = [
name,
check[slow_tokenizers[prefix]],
check[fast_tokenizers[prefix]],
check[pt_models[prefix]],
check[tf_models[prefix]],
check[flax_models[prefix]],
]
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
return table
def check_model_table(overwrite=False):
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
current_table, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
start_prompt="<!--This table is updated automatically from the auto modules",
end_prompt="<!-- End table-->",
)
new_table = get_model_table_from_auto_modules()
if current_table != new_table:
if overwrite:
with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
else:
raise ValueError(
"The model table in the `index.mdx` has not been updated. Run `make fix-copies` to fix this."
)
def has_onnx(model_type):
"""
Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not.
"""
config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
if model_type not in config_mapping:
return False
config = config_mapping[model_type]
config_module = config.__module__
module = transformers_module
for part in config_module.split(".")[1:]:
module = getattr(module, part)
config_name = config.__name__
onnx_config_name = config_name.replace("Config", "OnnxConfig")
return hasattr(module, onnx_config_name)
def get_onnx_model_list():
"""
Return the list of models supporting ONNX.
"""
config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
model_names = config_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)]
onnx_model_names = [model_names[model_type] for model_type in onnx_model_types]
onnx_model_names.sort(key=lambda x: x.upper())
return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n"
def check_onnx_model_list(overwrite=False):
"""Check the model list in the serialization.mdx is consistent with the state of the lib and maybe `overwrite`."""
current_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
end_prompt="In the next two sections, we'll show you how to:",
)
new_list = get_onnx_model_list()
if current_list != new_list:
if overwrite:
with open(os.path.join(PATH_TO_DOCS, "serialization.mdx"), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [new_list] + lines[end_index:])
else:
raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
check_model_table(args.fix_and_overwrite)
check_onnx_model_list(args.fix_and_overwrite)

101
utils/check_tf_ops.py Normal file
View File

@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
REPO_PATH = "."
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
INTERNAL_OPS = [
"Assert",
"AssignVariableOp",
"EmptyTensorList",
"MergeV2Checkpoints",
"ReadVariableOp",
"ResourceGather",
"RestoreV2",
"SaveV2",
"ShardedFilename",
"StatefulPartitionedCall",
"StaticRegexFullMatch",
"VarHandleOp",
]
def onnx_compliancy(saved_model_path, strict, opset):
saved_model = SavedModel()
onnx_ops = []
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
onnx_opsets = json.load(f)["opsets"]
for i in range(1, opset + 1):
onnx_ops.extend(onnx_opsets[str(i)])
with open(saved_model_path, "rb") as f:
saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
for meta_graph in saved_model.meta_graphs:
# Add operations in the graph definition
model_op_names.update(node.op for node in meta_graph.graph_def.node)
# Go through the functions in the graph definition
for func in meta_graph.graph_def.library.function:
# Add operations in each function
model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
incompatible_ops = []
for op in model_op_names:
if op not in onnx_ops and op not in INTERNAL_OPS:
incompatible_ops.append(op)
if strict and len(incompatible_ops) > 0:
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
elif len(incompatible_ops) > 0:
print(f"Found the following incompatible ops for the opset {opset}:")
print(*incompatible_ops, sep="\n")
else:
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
parser.add_argument(
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
)
parser.add_argument(
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
)
parser.add_argument(
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
)
args = parser.parse_args()
if args.framework == "onnx":
onnx_compliancy(args.saved_model_path, args.strict, args.opset)