diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..cf818346 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Initially taken from Github's Python gitignore file + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# tests and logs +tests/fixtures/cached_*_text.txt +logs/ +lightning_logs/ +lang_code_data/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# vscode +.vs +.vscode + +# Pycharm +.idea + +# TF code +tensorflow_code + +# Models +proc_data + +# examples +runs +/runs_old +/wandb +/examples/runs +/examples/**/*.args +/examples/rag/sweep + +# data +/data +serialization_dir + +# emacs +*.*~ +debug.env + +# vim +.*.swp + +#ctags +tags + +# pre-commit +.pre-commit* + +# .lock +*.lock + +# DS_Store (MacOS) +.DS_Store \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..f5c35730 --- /dev/null +++ b/Makefile @@ -0,0 +1,109 @@ +.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples + +# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 
+export PYTHONPATH = src
+
+check_dirs := models tests src utils
+
+modified_only_fixup:
+	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
+	@if test -n "$(modified_py_files)"; then \
+		echo "Checking/fixing $(modified_py_files)"; \
+		black --preview $(modified_py_files); \
+		isort $(modified_py_files); \
+		flake8 $(modified_py_files); \
+	else \
+		echo "No library .py files were modified"; \
+	fi
+
+# Update src/diffusers/dependency_versions_table.py
+
+deps_table_update:
+	@python setup.py deps_table_update
+
+deps_table_check_updated:
+	@md5sum src/diffusers/dependency_versions_table.py > md5sum.saved
+	@python setup.py deps_table_update
+	@md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
+	@rm md5sum.saved
+
+# autogenerating code
+
+autogenerate_code: deps_table_update
+
+# Check that the repo is in a good state
+
+repo-consistency:
+	python utils/check_copies.py
+	python utils/check_table.py
+	python utils/check_dummies.py
+	python utils/check_repo.py
+	python utils/check_inits.py
+	python utils/check_config_docstrings.py
+	python utils/tests_fetcher.py --sanity_check
+
+# this target runs checks on all files
+
+quality:
+	black --check --preview $(check_dirs)
+	isort --check-only $(check_dirs)
+	python utils/custom_init_isort.py --check_only
+	python utils/sort_auto_mappings.py --check_only
+	flake8 $(check_dirs)
+	doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
+
+# Format source code automatically and check if there are any problems left that need manual fixing
+
+extra_style_checks:
+	python utils/custom_init_isort.py
+	python utils/sort_auto_mappings.py
+	doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
+
+# this target runs checks on all files and potentially modifies some of them
+
+style:
+	black --preview $(check_dirs)
+	isort $(check_dirs)
+	${MAKE} autogenerate_code
+	${MAKE} extra_style_checks
+
+# Super fast fix and check target that only works on relevant modified files since the branch was made
+
+fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
+
+# Make marked copies of snippets of code conform to the original
+
+fix-copies:
+	python utils/check_copies.py --fix_and_overwrite
+	python utils/check_table.py --fix_and_overwrite
+	python utils/check_dummies.py --fix_and_overwrite
+
+# Run tests for the library
+
+test:
+	python -m pytest -n auto --dist=loadfile -s -v ./tests/
+
+# Run tests for examples
+
+test-examples:
+	python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/
+
+# Run tests for SageMaker DLC release
+
+test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker]
+	TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker
+
+
+# Release stuff
+
+pre-release:
+	python utils/release.py
+
+pre-patch:
+	python utils/release.py --patch
+
+post-release:
+	python utils/release.py --post_release
+
+post-patch:
+	python utils/release.py --post_release --patch
diff --git a/md5sum.saved b/md5sum.saved
new file mode 100644
index 00000000..2dc63471
--- /dev/null
+++ b/md5sum.saved
@@ -0,0 +1 @@
+ce075df80e7ba2391d63d026be165c15  src/diffusers/dependency_versions_table.py
diff --git a/setup.py b/setup.py
index 4021cd65..d902811b 100644
--- a/setup.py
+++ b/setup.py
@@ -12,8 +12,164 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from setuptools import setup
-from setuptools import find_packages
+"""
+Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py
+
+To create the package for pypi.
+
+1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
+   documentation.
+
+   If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
+   for the post-release and run `make fix-copies` on the main branch as well.
+
+2. Run the tests for Amazon SageMaker. The documentation is located in `./tests/sagemaker/README.md`; otherwise, ask @philschmid.
+
+3. Unpin specific versions from setup.py that use a git install.
+
+4. Check out the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
+   message: "Release: <RELEASE>" and push.
+
+5. Wait for the tests on main to complete and be green (otherwise revert and fix bugs).
+
+6. Add a tag in git to mark the release: "git tag v<RELEASE> -m 'Adds tag v<RELEASE> for pypi' "
+   Push the tag to git: git push --tags origin v<RELEASE>-release
+
+7. Build both the sources and the wheel. Do not change anything in setup.py between
+   creating the wheel and the source distribution (obviously).
+
+   For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
+   (This will build a wheel for the Python version you use to build it.)
+
+   For the sources, run: "python setup.py sdist"
+   You should now have a /dist directory with both .whl and .tar.gz source versions.
+
+8. Check that everything looks correct by uploading the package to the pypi test server:
+
+   twine upload dist/* -r pypitest
+   (pypi suggests using twine, as other methods upload files via plaintext.)
+   You may have to specify the repository URL; in that case, use the following command:
+   twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
+
+   Check that you can install it in a virtualenv by running:
+   pip install -i https://testpypi.python.org/pypi transformers
+
+   Check you can run the following commands:
+   python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
+   python -c "from transformers import *"
+
+9. Upload the final version to actual pypi:
+   twine upload dist/* -r pypi
+
+10. Copy the release notes from RELEASE.md to the tag in GitHub once everything is looking hunky-dory.
+
+11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
+    you need to go back to main before executing this.
+"""
+
+import re
+from distutils.core import Command
+
+from setuptools import find_packages, setup
+
+# IMPORTANT:
+# 1. all dependencies should be listed here with their version requirements if any
+# 2.
once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
+_deps = [
+    "Pillow",
+    "accelerate>=0.9.0",
+    "black~=22.0,>=22.3",
+    "codecarbon==1.2.0",
+    "dataclasses",
+    "datasets",
+    "GitPython<3.1.19",
+    "hf-doc-builder>=0.3.0",
+    "huggingface-hub>=0.1.0,<1.0",
+    "importlib_metadata",
+    "isort>=5.5.4",
+    "numpy>=1.17",
+    "pytest",
+    "pytest-timeout",
+    "pytest-xdist",
+    "python>=3.7.0",
+    "regex!=2019.12.17",
+    "requests",
+    "sagemaker>=2.31.0",
+    "tokenizers>=0.11.1,!=0.11.3,<0.13",
+    "torch>=1.4",
+    "torchaudio",
+    "tqdm>=4.27",
+    "unidic>=1.0.2",
+    "unidic_lite>=1.0.7",
+    "uvicorn",
+]
+
+# this is a lookup table with items like:
+#
+# tokenizers: "tokenizers==0.9.4"
+# packaging: "packaging"
+#
+# some of the values are versioned whereas others aren't.
+deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
+
+# since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from
+# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
+#
+# python -c 'import sys; from diffusers.dependency_versions_table import deps; \
+# print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets
+#
+# Just pass the desired package names to that script as shown with the 2 packages above.
+#
+# If diffusers is not yet installed and the work is done from the cloned repo, remember to add `PYTHONPATH=src` to the script above.
+#
+# You can then feed this, for example, to `pip`:
+#
+# pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \
+# print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets)
+#
+
+
+def deps_list(*pkgs):
+    return [deps[pkg] for pkg in pkgs]
+
+
+class DepsTableUpdateCommand(Command):
+    """
+    A custom distutils command that updates the dependency table.
+    usage: python setup.py deps_table_update
+    """
+
+    description = "build runtime dependency table"
+    user_options = [
+        # format: (long option, short option, description).
+        ("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"),
+    ]
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def run(self):
+        entries = "\n".join([f'    "{k}": "{v}",' for k, v in deps.items()])
+        content = [
+            "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
+            "# 1. modify the `_deps` dict in setup.py",
+            "# 2. run `make deps_table_update`",
+            "deps = {",
+            entries,
+            "}",
+            "",
+        ]
+        target = "src/diffusers/dependency_versions_table.py"
+        print(f"updating {target}")
+        with open(target, "w", encoding="utf-8", newline="\n") as f:
+            f.write("\n".join(content))
+
+
 extras = {}
 
 extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
@@ -61,6 +217,7 @@ setup(
         "Programming Language :: Python :: 3.9",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
+    cmdclass={"deps_table_update": DepsTableUpdateCommand},
 )
 
 # Release checklist
diff --git a/src/diffusers/dependency_versions_check.py b/src/diffusers/dependency_versions_check.py
new file mode 100644
index 00000000..bbf86322
--- /dev/null
+++ b/src/diffusers/dependency_versions_check.py
@@ -0,0 +1,47 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+#
+# note: only list packages that are present in the `deps` table, since any
+# missing entry raises a ValueError at import time
+
+pkgs_to_check_at_runtime = "python tqdm regex requests numpy tokenizers".split()
+if sys.version_info < (3, 7):
+    pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+    pkgs_to_check_at_runtime.append("importlib_metadata")
+
+for pkg in pkgs_to_check_at_runtime:
+    if pkg in deps:
+        if pkg == "tokenizers":
+            # must be loaded here, or else tqdm check may fail
+            from .utils import is_tokenizers_available
+
+            if not is_tokenizers_available():
+                continue  # not required, check version only if installed
+
+        require_version_core(deps[pkg])
+    else:
+        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+    require_version(deps[pkg], hint)
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
new file mode 100644
index 00000000..6b552e0d
--- /dev/null
+++ b/src/diffusers/dependency_versions_table.py
@@ -0,0 +1,31 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update`
+deps = {
+    "Pillow": "Pillow",
+    "accelerate": "accelerate>=0.9.0",
+    "black": "black~=22.0,>=22.3",
+    "codecarbon": "codecarbon==1.2.0",
+    "dataclasses": "dataclasses",
+    "datasets": "datasets",
+    "GitPython": "GitPython<3.1.19",
+    "hf-doc-builder": "hf-doc-builder>=0.3.0",
+    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
+    "importlib_metadata": "importlib_metadata",
+    "isort": "isort>=5.5.4",
+    "numpy": "numpy>=1.17",
+    "pytest": "pytest",
+    "pytest-timeout": "pytest-timeout",
+    "pytest-xdist": "pytest-xdist",
+    "python": "python>=3.7.0",
+    "regex": "regex!=2019.12.17",
+    "requests": "requests",
+    "sagemaker": "sagemaker>=2.31.0",
+    "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
+    "torch": "torch>=1.4",
+    "torchaudio": "torchaudio",
+    "tqdm": "tqdm>=4.27",
+    "unidic": "unidic>=1.0.2",
+    "unidic_lite": "unidic_lite>=1.0.7",
+    "uvicorn": "uvicorn",
+}
diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py
index f34bc16a..eff34741 100644
--- a/src/diffusers/models/unet.py
+++ b/src/diffusers/models/unet.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-class UNetModel:
+class UNetModel:
     def __init__(self, config):
         self.config = config
         print("I can diffuse!")
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
new file mode 100644
index 00000000..382f42bf
--- /dev/null
+++ b/utils/check_config_docstrings.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+import re
+
+
+# All paths are set with the intent you should run this script from the root of the repo with the command
+# python utils/check_config_docstrings.py
+PATH_TO_TRANSFORMERS = "src/transformers"
+
+
+# This is to make sure the transformers module imported is the one in the repo.
+spec = importlib.util.spec_from_file_location(
+    "transformers",
+    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
+    submodule_search_locations=[PATH_TO_TRANSFORMERS],
+)
+transformers = spec.loader.load_module()
+
+CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
+
+# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
+# For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)`
+_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
+
+
+CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
+    "CLIPConfig",
+    "DecisionTransformerConfig",
+    "EncoderDecoderConfig",
+    "RagConfig",
+    "SpeechEncoderDecoderConfig",
+    "VisionEncoderDecoderConfig",
+    "VisionTextDualEncoderConfig",
+}
+
+
+def check_config_docstrings_have_checkpoints():
+    configs_without_checkpoint = []
+
+    for config_class in list(CONFIG_MAPPING.values()):
+        checkpoint_found = False
+
+        # source code of `config_class`
+        config_source = inspect.getsource(config_class)
+        checkpoints = _re_checkpoint.findall(config_source)
+
+        for checkpoint in checkpoints:
+            # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
+            # For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')`
+            ckpt_name, ckpt_link = checkpoint
+
+            # verify the checkpoint name corresponds to the checkpoint link
+            ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
+            if ckpt_link == ckpt_link_from_name:
+                checkpoint_found = True
+                break
+
+        name = config_class.__name__
+        if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
+            configs_without_checkpoint.append(name)
+
+    if len(configs_without_checkpoint) > 0:
+        message = "\n".join(sorted(configs_without_checkpoint))
+        raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}")
+
+
+if __name__ == "__main__":
+    check_config_docstrings_have_checkpoints()
diff --git a/utils/check_copies.py b/utils/check_copies.py
new file mode 100644
index 00000000..7565bfa5
--- /dev/null
+++ b/utils/check_copies.py
@@ -0,0 +1,458 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
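+
+# This script checks that code snippets marked with a "# Copied from transformers.<module>.<object>"
+# comment stay identical to the code they were copied from, and can rewrite them via
+# `--fix_and_overwrite`. A marked copy looks like the following (hypothetical example; the optional
+# `with X->Y` suffix is the rename pattern handled by `_re_replace_pattern` below, and `all-casing`
+# extends the replacement to lower- and upper-cased variants):
+#
+#     # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Roberta
+#     class RobertaSelfOutput(nn.Module):
+#         ...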
+
+import argparse
+import glob
+import importlib.util
+import os
+import re
+
+import black
+from doc_builder.style_doc import style_docstrings_in_code
+
+
+# All paths are set with the intent you should run this script from the root of the repo with the command
+# python utils/check_copies.py
+TRANSFORMERS_PATH = "src/transformers"
+PATH_TO_DOCS = "docs/source/en"
+REPO_PATH = "."
+
+# This is to make sure the transformers module imported is the one in the repo (it is needed by the
+# ONNX model-list helpers below).
+spec = importlib.util.spec_from_file_location(
+    "transformers",
+    os.path.join(TRANSFORMERS_PATH, "__init__.py"),
+    submodule_search_locations=[TRANSFORMERS_PATH],
+)
+transformers_module = spec.loader.load_module()
+
+# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to date with)
+FULL_COPIES = {
+    "examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
+    "examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
+}
+
+
+LOCALIZED_READMES = {
+    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
+    "README.md": {
+        "start_prompt": "🤗 Transformers currently provides the following architectures",
+        "end_prompt": "1. Want to contribute a new model?",
+        "format_model_list": (
+            "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+            " {paper_authors}.{supplements}"
+        ),
+    },
+    "README_zh-hans.md": {
+        "start_prompt": "🤗 Transformers 目前支持如下的架构",
+        "end_prompt": "1. 想要贡献新的模型?",
+        "format_model_list": (
+            "**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link} 由 {paper_authors}"
+            " 发布。{supplements}"
+        ),
+    },
+    "README_zh-hant.md": {
+        "start_prompt": "🤗 Transformers 目前支援以下的架構",
+        "end_prompt": "1. 想要貢獻新的模型?",
+        "format_model_list": (
+            "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+            " {paper_authors}.{supplements}"
+        ),
+    },
+    "README_ko.md": {
+        "start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다",
+        "end_prompt": "1. 새로운 모델을 올리고 싶나요?",
+        "format_model_list": (
+            "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+            " {paper_authors}.{supplements}"
+        ),
+    },
+}
+
+
+def _should_continue(line, indent):
+    return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
+
+
+def find_code_in_transformers(object_name):
+    """Find and return the source code of `object_name`."""
+    parts = object_name.split(".")
+    i = 0
+
+    # First let's find the module where our object lives.
+    module = parts[i]
+    while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
+        i += 1
+        if i < len(parts):
+            module = os.path.join(module, parts[i])
+    if i >= len(parts):
+        raise ValueError(
+            f"`object_name` should begin with the name of a module of transformers but got {object_name}."
+        )
+
+    with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+
+    # Now let's find the class / func in the code!
+    indent = ""
+    line_index = 0
+    for name in parts[i + 1 :]:
+        while (
+            line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
+        ):
+            line_index += 1
+        indent += "    "
+        line_index += 1
+
+    if line_index >= len(lines):
+        raise ValueError(f"{object_name} does not match any function or class in {module}.")
+
+    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
+    start_index = line_index
+    while line_index < len(lines) and _should_continue(lines[line_index], indent):
+        line_index += 1
+    # Clean up empty lines at the end (if any).
+    while len(lines[line_index - 1]) <= 1:
+        line_index -= 1
+
+    code_lines = lines[start_index:line_index]
+    return "".join(code_lines)
+
+
+_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
+_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
+
+
+def get_indent(code):
+    lines = code.split("\n")
+    idx = 0
+    while idx < len(lines) and len(lines[idx]) == 0:
+        idx += 1
+    if idx < len(lines):
+        return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
+    return ""
+
+
+def blackify(code):
+    """
+    Applies the black part of our `make style` command to `code`.
+    """
+    has_indent = len(get_indent(code)) > 0
+    if has_indent:
+        code = f"class Bla:\n{code}"
+    mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True)
+    result = black.format_str(code, mode=mode)
+    result, _ = style_docstrings_in_code(result)
+    return result[len("class Bla:\n") :] if has_indent else result
+
+
+def is_copy_consistent(filename, overwrite=False):
+    """
+    Check if the code commented as a copy in `filename` matches the original.
+
+    Return the differences or overwrites the content depending on `overwrite`.
+    """
+    with open(filename, "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+    diffs = []
+    line_index = 0
+    # Not a for loop because `lines` is going to change (if `overwrite=True`).
+    while line_index < len(lines):
+        search = _re_copy_warning.search(lines[line_index])
+        if search is None:
+            line_index += 1
+            continue
+
+        # There is some copied code here, let's retrieve the original.
+        indent, object_name, replace_pattern = search.groups()
+        theoretical_code = find_code_in_transformers(object_name)
+        theoretical_indent = get_indent(theoretical_code)
+
+        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
+        indent = theoretical_indent
+        line_index = start_index
+
+        # Loop to check the observed code, stop when indentation diminishes or if we see an `# End copy` comment.
+        should_continue = True
+        while line_index < len(lines) and should_continue:
+            line_index += 1
+            if line_index >= len(lines):
+                break
+            line = lines[line_index]
+            should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
+        # Clean up empty lines at the end (if any).
+        while len(lines[line_index - 1]) <= 1:
+            line_index -= 1
+
+        observed_code_lines = lines[start_index:line_index]
+        observed_code = "".join(observed_code_lines)
+
+        # Before comparing, use the `replace_pattern` on the original code.
+        if len(replace_pattern) > 0:
+            patterns = replace_pattern.replace("with", "").split(",")
+            patterns = [_re_replace_pattern.search(p) for p in patterns]
+            for pattern in patterns:
+                if pattern is None:
+                    continue
+                obj1, obj2, option = pattern.groups()
+                theoretical_code = re.sub(obj1, obj2, theoretical_code)
+                if option.strip() == "all-casing":
+                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
+                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
+
+        # Blackify after replacement. To be able to do that, we need the header (class or function definition)
+        # from the previous line
+        theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
+        theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
+
+        # Test for a diff and act accordingly.
+        if observed_code != theoretical_code:
+            diffs.append([object_name, start_index])
+            if overwrite:
+                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
+                line_index = start_index + 1
+
+    if overwrite and len(diffs) > 0:
+        # Warn the user a file has been modified.
+        print(f"Detected changes, rewriting {filename}.")
+        with open(filename, "w", encoding="utf-8", newline="\n") as f:
+            f.writelines(lines)
+    return diffs
+
+
+def check_copies(overwrite: bool = False):
+    all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
+    diffs = []
+    for filename in all_files:
+        new_diffs = is_copy_consistent(filename, overwrite)
+        diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
+    if not overwrite and len(diffs) > 0:
+        diff = "\n".join(diffs)
+        raise Exception(
+            "Found the following copy inconsistencies:\n"
+            + diff
+            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
+        )
+    check_model_list_copy(overwrite=overwrite)
+
+
+def check_full_copies(overwrite: bool = False):
+    diffs = []
+    for target, source in FULL_COPIES.items():
+        with open(source, "r", encoding="utf-8") as f:
+            source_code = f.read()
+        with open(target, "r", encoding="utf-8") as f:
+            target_code = f.read()
+        if source_code != target_code:
+            if overwrite:
+                with open(target, "w", encoding="utf-8") as f:
+                    print(f"Replacing the content of {target} with that of {source}.")
+                    f.write(source_code)
+            else:
+                diffs.append(f"- {target}: copy does not match {source}.")
+
+    if not overwrite and len(diffs) > 0:
+        diff = "\n".join(diffs)
+        raise Exception(
+            "Found the following copy inconsistencies:\n"
+            + diff
+            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
+        )
+
+
+def get_model_list(filename, start_prompt, end_prompt):
+    """Extracts the model list from the README."""
+    with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+    # Find the start of the list.
+    start_index = 0
+    while not lines[start_index].startswith(start_prompt):
+        start_index += 1
+    start_index += 1
+
+    result = []
+    current_line = ""
+    end_index = start_index
+
+    while not lines[end_index].startswith(end_prompt):
+        if lines[end_index].startswith("1."):
+            if len(current_line) > 1:
+                result.append(current_line)
+            current_line = lines[end_index]
+        elif len(lines[end_index]) > 1:
+            current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}"
+        end_index += 1
+    if len(current_line) > 1:
+        result.append(current_line)
+
+    return "".join(result)
+
+
+def convert_to_localized_md(model_list, localized_model_list, format_str):
+    """Convert `model_list` to each localized README."""
+
+    def _rep(match):
+        title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
+        return format_str.format(
+            title=title,
+            model_link=model_link,
+            paper_affiliations=paper_affiliations,
+            paper_title_link=paper_title_link,
+            paper_authors=paper_authors,
+            supplements=" " + supplements.strip() if len(supplements) != 0 else "",
+        )
+
+    # This regex captures metadata from an English model description, including the model title, model link,
+    # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for an example).
+    _re_capture_meta = re.compile(
+        r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
+    )
+    # This regex is used to synchronize links.
+    _re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")
+
+    if len(localized_model_list) == 0:
+        localized_model_index = {}
+    else:
+        try:
+            localized_model_index = {
+                re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line
+                for line in localized_model_list.strip().split("\n")
+            }
+        except AttributeError:
+            raise AttributeError("A model name in localized READMEs cannot be recognized.")
+
+    model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]
+
+    # We exclude keys in the localized README not in the main one.
+    readmes_match = not any([k not in model_keys for k in localized_model_index])
+    localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}
+
+    for model in model_list.strip().split("\n"):
+        title, model_link = _re_capture_title_link.search(model).groups()
+        if title not in localized_model_index:
+            readmes_match = False
+            # Add an anchor white space behind a model description string for regex.
+            # If metadata cannot be captured, the English version will be directly copied.
+            localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
+        else:
+            # Synchronize link
+            localized_model_index[title] = _re_capture_title_link.sub(
+                f"**[{title}]({model_link})**", localized_model_index[title], count=1
+            )
+
+    sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())
+
+    return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
+
+
+def convert_readme_to_index(model_list):
+    model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
+    return model_list.replace("https://huggingface.co/docs/transformers/", "")
+
+
+def _find_text_in_file(filename, start_prompt, end_prompt):
+    """
+    Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
+    lines.
+    """
+    with open(filename, "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+    # Find the start prompt.
+    start_index = 0
+    while not lines[start_index].startswith(start_prompt):
+        start_index += 1
+    start_index += 1
+
+    end_index = start_index
+    while not lines[end_index].startswith(end_prompt):
+        end_index += 1
+    end_index -= 1
+
+    while len(lines[start_index]) <= 1:
+        start_index += 1
+    while len(lines[end_index]) <= 1:
+        end_index -= 1
+    end_index += 1
+    return "".join(lines[start_index:end_index]), start_index, end_index, lines
+
+
+def check_model_list_copy(overwrite=False, max_per_line=119):
+    """Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
+    # Fix potential doc links in the README
+    with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
+        readme = f.read()
+    new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
+    new_readme = new_readme.replace(
+        "https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
+    )
+    if new_readme != readme:
+        if overwrite:
+            with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f:
+                f.write(new_readme)
+        else:
+            raise ValueError(
+                "The main README contains wrong links to the documentation of Transformers. Run `make fix-copies` to "
+                "automatically fix them."
+            )
+
+    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
+    current_table, start_index, end_index, lines = _find_text_in_file(
+        filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
+        start_prompt="<!--This table is updated automatically from the auto modules with _make fix-copies_. Do not edit that part!-->",
+        end_prompt="<!-- End table-->",
+    )
+    # `get_model_table_from_auto_modules` is provided by utils/check_table.py.
+    new_table = get_model_table_from_auto_modules()
+
+    if current_table != new_table:
+        if overwrite:
+            with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
+                f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
+        else:
+            raise ValueError(
+                "The model table in the `index.mdx` has not been updated. Run `make fix-copies` to fix this."
+            )
+
+
+def has_onnx(model_type):
+    """
+    Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not.
+    """
+    config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
+    if model_type not in config_mapping:
+        return False
+    config = config_mapping[model_type]
+    config_module = config.__module__
+    module = transformers_module
+    for part in config_module.split(".")[1:]:
+        module = getattr(module, part)
+    config_name = config.__name__
+    onnx_config_name = config_name.replace("Config", "OnnxConfig")
+    return hasattr(module, onnx_config_name)
+
+
+def get_onnx_model_list():
+    """
+    Return the list of models supporting ONNX.
+    """
+    config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
+    model_names = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
+    onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)]
+    onnx_model_names = [model_names[model_type] for model_type in onnx_model_types]
+    onnx_model_names.sort(key=lambda x: x.upper())
+    return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n"
+
+
+def check_onnx_model_list(overwrite=False):
+    """Check the model list in the serialization.mdx is consistent with the state of the lib and maybe `overwrite`."""
+    current_list, start_index, end_index, lines = _find_text_in_file(
+        filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
+        start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
+        end_prompt="In the next two sections, we'll show you how to:",
+    )
+    new_list = get_onnx_model_list()
+
+    if current_list != new_list:
+        if overwrite:
+            with open(os.path.join(PATH_TO_DOCS, "serialization.mdx"), "w", encoding="utf-8", newline="\n") as f:
+                f.writelines(lines[:start_index] + [new_list] + lines[end_index:])
+        else:
+            raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
+    args = parser.parse_args()
+
+    check_copies(args.fix_and_overwrite)
+    check_full_copies(args.fix_and_overwrite)
+    check_onnx_model_list(args.fix_and_overwrite)
diff --git a/utils/check_tf_ops.py b/utils/check_tf_ops.py
new file mode 100644
index 00000000..f6c2b8ba
--- /dev/null
+++ b/utils/check_tf_ops.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
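+
+# This script loads a TensorFlow SavedModel and checks that every op it uses is either an internal
+# TF op or covered by the requested ONNX opset, using the tables in utils/tf_ops/onnx.json.
+# Hypothetical usage example (the path is a placeholder):
+#
+#     python utils/check_tf_ops.py --saved_model_path model/saved_model.pb --opset 12 --strict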
+
+import argparse
+import json
+import os
+
+from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
+
+
+# All paths are set with the intent you should run this script from the root of the repo with the command
+# python utils/check_tf_ops.py
+REPO_PATH = "."
+
+# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
+INTERNAL_OPS = [
+    "Assert",
+    "AssignVariableOp",
+    "EmptyTensorList",
+    "MergeV2Checkpoints",
+    "ReadVariableOp",
+    "ResourceGather",
+    "RestoreV2",
+    "SaveV2",
+    "ShardedFilename",
+    "StatefulPartitionedCall",
+    "StaticRegexFullMatch",
+    "VarHandleOp",
+]
+
+
+def onnx_compliancy(saved_model_path, strict, opset):
+    saved_model = SavedModel()
+    onnx_ops = []
+
+    with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
+        onnx_opsets = json.load(f)["opsets"]
+
+    for i in range(1, opset + 1):
+        onnx_ops.extend(onnx_opsets[str(i)])
+
+    with open(saved_model_path, "rb") as f:
+        saved_model.ParseFromString(f.read())
+
+    model_op_names = set()
+
+    # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
+    for meta_graph in saved_model.meta_graphs:
+        # Add operations in the graph definition
+        model_op_names.update(node.op for node in meta_graph.graph_def.node)
+
+        # Go through the functions in the graph definition
+        for func in meta_graph.graph_def.library.function:
+            # Add operations in each function
+            model_op_names.update(node.op for node in func.node_def)
+
+    # Convert to a sorted list
+    model_op_names = sorted(model_op_names)
+    incompatible_ops = []
+
+    for op in model_op_names:
+        if op not in onnx_ops and op not in INTERNAL_OPS:
+            incompatible_ops.append(op)
+
+    if strict and len(incompatible_ops) > 0:
+        raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + "\n".join(incompatible_ops))
+    elif len(incompatible_ops) > 0:
+        print(f"Found the following incompatible ops for the opset {opset}:")
+        print(*incompatible_ops, sep="\n")
+    else:
+        print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
+    parser.add_argument(
+        "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
+    )
+    parser.add_argument(
+        "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
+    )
+    parser.add_argument(
+        "--strict", action="store_true", help="Whether to make the check strict (raise errors) or not (raise warnings)"
+    )
+    args = parser.parse_args()
+
+    if args.framework == "onnx":
+        onnx_compliancy(args.saved_model_path, args.strict, args.opset)