diff --git a/.github/workflows/pr_quality.yml b/.github/workflows/pr_quality.yml index 5a850bea..8d6e20ef 100644 --- a/.github/workflows/pr_quality.yml +++ b/.github/workflows/pr_quality.yml @@ -31,3 +31,20 @@ jobs: isort --check-only examples tests src utils scripts flake8 examples tests src utils scripts doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source + + check_repository_consistency: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.7" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check quality + run: | + python utils/check_copies.py + python utils/check_dummies.py diff --git a/Makefile b/Makefile index 6e513e2e..ea7537f2 100644 --- a/Makefile +++ b/Makefile @@ -67,6 +67,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency # Make marked copies of snippets of codes conform to the original fix-copies: + python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite # Run tests for the library diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index b94cada4..58e40b57 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -28,6 +28,7 @@ from .scheduling_utils import SchedulerMixin @dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 2a5c1c39..768413e9 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -26,6 +26,7 @@ from .scheduling_utils import SchedulerMixin @dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete class LMSDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. diff --git a/tests/repo_utils/test_check_copies.py b/tests/repo_utils/test_check_copies.py new file mode 100644 index 00000000..65128f68 --- /dev/null +++ b/tests/repo_utils/test_check_copies.py @@ -0,0 +1,120 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import shutil +import sys +import tempfile +import unittest + +import black + + +git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(os.path.join(git_repo_path, "utils")) + +import check_copies # noqa: E402 + + +# This is the reference code that will be used in the tests. +# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated. +REFERENCE_CODE = """ \""" + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + \""" + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None +""" + + +class CopyCheckTester(unittest.TestCase): + def setUp(self): + self.diffusers_dir = tempfile.mkdtemp() + os.makedirs(os.path.join(self.diffusers_dir, "schedulers/")) + check_copies.DIFFUSERS_PATH = self.diffusers_dir + shutil.copy( + os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"), + os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"), + ) + + def tearDown(self): + check_copies.DIFFUSERS_PATH = "src/diffusers" + shutil.rmtree(self.diffusers_dir) + + def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None): + code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code + if overwrite_result is not None: + expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result + mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119) + code = black.format_str(code, mode=mode) + fname = os.path.join(self.diffusers_dir, "new_code.py") + with open(fname, "w", newline="\n") as f: + f.write(code) + if overwrite_result is None: + self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0) + else: + check_copies.is_copy_consistent(f.name, overwrite=True) + with open(fname, "r") as f: + self.assertTrue(f.read(), expected) + + def test_find_code_in_diffusers(self): + code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput") + self.assertEqual(code, REFERENCE_CODE) + + def test_is_copy_consistent(self): + # Base copy consistency + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput", + "DDPMSchedulerOutput", + REFERENCE_CODE + "\n", + ) + + # With no empty line at the end + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput", + "DDPMSchedulerOutput", + REFERENCE_CODE, + ) + + # Copy consistency with rename + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test", + "TestSchedulerOutput", + re.sub("DDPM", "Test", REFERENCE_CODE), + ) + + # Copy consistency with a really long name + long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason" + self.check_copy_consistency( + f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}", + f"{long_class_name}SchedulerOutput", + re.sub("Bert", long_class_name, REFERENCE_CODE), + ) + + # Copy consistency with overwrite + self.check_copy_consistency( + "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test", + "TestSchedulerOutput", + REFERENCE_CODE, + overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE), + ) diff --git a/tests/repo_utils/test_check_dummies.py b/tests/repo_utils/test_check_dummies.py new file mode 100644 index 00000000..d8fa9ce1 --- /dev/null +++ b/tests/repo_utils/test_check_dummies.py @@ -0,0 +1,124 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest + + +git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(os.path.join(git_repo_path, "utils")) + +import check_dummies +from check_dummies import create_dummy_files, create_dummy_object, find_backend, read_init # noqa: E402 + + +# Align TRANSFORMERS_PATH in check_dummies with the current path +check_dummies.PATH_TO_DIFFUSERS = os.path.join(git_repo_path, "src", "diffusers") + + +class CheckDummiesTester(unittest.TestCase): + def test_find_backend(self): + simple_backend = find_backend(" if not is_torch_available():") + self.assertEqual(simple_backend, "torch") + + # backend_with_underscore = find_backend(" if not is_tensorflow_text_available():") + # self.assertEqual(backend_with_underscore, "tensorflow_text") + + double_backend = find_backend(" if not (is_torch_available() and is_transformers_available()):") + self.assertEqual(double_backend, "torch_and_transformers") + + # double_backend_with_underscore = find_backend( + # " if not (is_sentencepiece_available() and is_tensorflow_text_available()):" + # ) + # self.assertEqual(double_backend_with_underscore, "sentencepiece_and_tensorflow_text") + + triple_backend = find_backend( + " if not (is_torch_available() and is_transformers_available() and is_onnx_available()):" + ) + self.assertEqual(triple_backend, "torch_and_transformers_and_onnx") + + def test_read_init(self): + objects = read_init() + # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects + self.assertIn("torch", objects) + self.assertIn("torch_and_transformers", objects) + self.assertIn("flax_and_transformers", objects) + self.assertIn("torch_and_transformers_and_onnx", objects) + + # Likewise, we can't assert on the exact content of a key + self.assertIn("UNet2DModel", objects["torch"]) + self.assertIn("FlaxUNet2DConditionModel", objects["flax"]) + self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"]) + self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"]) + self.assertIn("LMSDiscreteScheduler", objects["torch_and_scipy"]) + self.assertIn("OnnxStableDiffusionPipeline", objects["torch_and_transformers_and_onnx"]) + + def test_create_dummy_object(self): + dummy_constant = create_dummy_object("CONSTANT", "'torch'") + self.assertEqual(dummy_constant, "\nCONSTANT = None\n") + + dummy_function = create_dummy_object("function", "'torch'") + self.assertEqual( + dummy_function, "\ndef function(*args, **kwargs):\n requires_backends(function, 'torch')\n" + ) + + expected_dummy_class = """ +class FakeClass(metaclass=DummyObject): + _backends = 'torch' + + def __init__(self, *args, **kwargs): + requires_backends(self, 'torch') + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, 'torch') + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, 'torch') +""" + dummy_class = create_dummy_object("FakeClass", "'torch'") + self.assertEqual(dummy_class, expected_dummy_class) + + def test_create_dummy_files(self): + expected_dummy_pytorch_file = """# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +CONSTANT = None + + +def function(*args, **kwargs): + requires_backends(function, ["torch"]) + + +class FakeClass(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) +""" + dummy_files = create_dummy_files({"torch": ["CONSTANT", "function", "FakeClass"]}) + self.assertEqual(dummy_files["torch"], expected_dummy_pytorch_file) diff --git a/utils/check_copies.py b/utils/check_copies.py index 50f02cac..395cefb9 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. +# Copyright 2022 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ import argparse import glob +import importlib.util import os import re @@ -24,52 +25,17 @@ from doc_builder.style_doc import style_docstrings_in_code # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_copies.py -TRANSFORMERS_PATH = "src/diffusers" -PATH_TO_DOCS = "docs/source/en" +DIFFUSERS_PATH = "src/diffusers" REPO_PATH = "." -# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with) -FULL_COPIES = { - "examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py", - "examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py", -} - -LOCALIZED_READMES = { - # If the introduction or the conclusion of the list change, the prompts may need to be updated. - "README.md": { - "start_prompt": "🤗 Transformers currently provides the following architectures", - "end_prompt": "1. Want to contribute a new model?", - "format_model_list": ( - "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by" - " {paper_authors}.{supplements}" - ), - }, - "README_zh-hans.md": { - "start_prompt": "🤗 Transformers 目前支持如下的架构", - "end_prompt": "1. 想要贡献新的模型?", - "format_model_list": ( - "**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link} 由 {paper_authors}" - " 发布。{supplements}" - ), - }, - "README_zh-hant.md": { - "start_prompt": "🤗 Transformers 目前支援以下的架構", - "end_prompt": "1. 想要貢獻新的模型?", - "format_model_list": ( - "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by" - " {paper_authors}.{supplements}" - ), - }, - "README_ko.md": { - "start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다", - "end_prompt": "1. 새로운 모델을 올리고 싶나요?", - "format_model_list": ( - "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by" - " {paper_authors}.{supplements}" - ), - }, -} +# This is to make sure the diffusers module imported is the one in the repo. +spec = importlib.util.spec_from_file_location( + "diffusers", + os.path.join(DIFFUSERS_PATH, "__init__.py"), + submodule_search_locations=[DIFFUSERS_PATH], +) +diffusers_module = spec.loader.load_module() def _should_continue(line, indent): @@ -83,14 +49,14 @@ def find_code_in_diffusers(object_name): # First let's find the module where our object lives. module = parts[i] - while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")): + while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")): i += 1 if i < len(parts): module = os.path.join(module, parts[i]) if i >= len(parts): raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.") - with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: + with open(os.path.join(DIFFUSERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Now let's find the class / func in the code! @@ -121,6 +87,7 @@ def find_code_in_diffusers(object_name): _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)") _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") +_re_fill_pattern = re.compile(r"]*>") def get_indent(code): @@ -140,7 +107,7 @@ def blackify(code): has_indent = len(get_indent(code)) > 0 if has_indent: code = f"class Bla:\n{code}" - mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True) + mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True) result = black.format_str(code, mode=mode) result, _ = style_docstrings_in_code(result) return result[len("class Bla:\n") :] if has_indent else result @@ -149,7 +116,6 @@ def blackify(code): def is_copy_consistent(filename, overwrite=False): """ Check if the code commented as a copy in `filename` matches the original. - Return the differences or overwrites the content depending on `overwrite`. """ with open(filename, "r", encoding="utf-8", newline="\n") as f: @@ -221,7 +187,7 @@ def is_copy_consistent(filename, overwrite=False): def check_copies(overwrite: bool = False): - all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) + all_files = glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True) diffs = [] for filename in all_files: new_diffs = is_copy_consistent(filename, overwrite) @@ -235,224 +201,9 @@ def check_copies(overwrite: bool = False): ) -# check_model_list_copy(overwrite=overwrite) - - -def check_full_copies(overwrite: bool = False): - diffs = [] - for target, source in FULL_COPIES.items(): - with open(source, "r", encoding="utf-8") as f: - source_code = f.read() - with open(target, "r", encoding="utf-8") as f: - target_code = f.read() - if source_code != target_code: - if overwrite: - with open(target, "w", encoding="utf-8") as f: - print(f"Replacing the content of {target} by the one of {source}.") - f.write(source_code) - else: - diffs.append(f"- {target}: copy does not match {source}.") - - if not overwrite and len(diffs) > 0: - diff = "\n".join(diffs) - raise Exception( - "Found the following copy inconsistencies:\n" - + diff - + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them." - ) - - -def get_model_list(filename, start_prompt, end_prompt): - """Extracts the model list from the README.""" - with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f: - lines = f.readlines() - # Find the start of the list. - start_index = 0 - while not lines[start_index].startswith(start_prompt): - start_index += 1 - start_index += 1 - - result = [] - current_line = "" - end_index = start_index - - while not lines[end_index].startswith(end_prompt): - if lines[end_index].startswith("1."): - if len(current_line) > 1: - result.append(current_line) - current_line = lines[end_index] - elif len(lines[end_index]) > 1: - current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}" - end_index += 1 - if len(current_line) > 1: - result.append(current_line) - - return "".join(result) - - -def convert_to_localized_md(model_list, localized_model_list, format_str): - """Convert `model_list` to each localized README.""" - - def _rep(match): - title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups() - return format_str.format( - title=title, - model_link=model_link, - paper_affiliations=paper_affiliations, - paper_title_link=paper_title_link, - paper_authors=paper_authors, - supplements=" " + supplements.strip() if len(supplements) != 0 else "", - ) - - # This regex captures metadata from an English model description, including model title, model link, - # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example). - _re_capture_meta = re.compile( - r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$" - ) - # This regex is used to synchronize link. - _re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*") - - if len(localized_model_list) == 0: - localized_model_index = {} - else: - try: - localized_model_index = { - re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line - for line in localized_model_list.strip().split("\n") - } - except AttributeError: - raise AttributeError("A model name in localized READMEs cannot be recognized.") - - model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")] - - # We exclude keys in localized README not in the main one. - readmes_match = not any([k not in model_keys for k in localized_model_index]) - localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys} - - for model in model_list.strip().split("\n"): - title, model_link = _re_capture_title_link.search(model).groups() - if title not in localized_model_index: - readmes_match = False - # Add an anchor white space behind a model description string for regex. - # If metadata cannot be captured, the English version will be directly copied. - localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ") - else: - # Synchronize link - localized_model_index[title] = _re_capture_title_link.sub( - f"**[{title}]({model_link})**", localized_model_index[title], count=1 - ) - - sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower()) - - return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n" - - -def convert_readme_to_index(model_list): - model_list = model_list.replace("https://huggingface.co/docs/diffusers/main/", "") - return model_list.replace("https://huggingface.co/docs/diffusers/", "") - - -def _find_text_in_file(filename, start_prompt, end_prompt): - """ - Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty - lines. - """ - with open(filename, "r", encoding="utf-8", newline="\n") as f: - lines = f.readlines() - # Find the start prompt. - start_index = 0 - while not lines[start_index].startswith(start_prompt): - start_index += 1 - start_index += 1 - - end_index = start_index - while not lines[end_index].startswith(end_prompt): - end_index += 1 - end_index -= 1 - - while len(lines[start_index]) <= 1: - start_index += 1 - while len(lines[end_index]) <= 1: - end_index -= 1 - end_index += 1 - return "".join(lines[start_index:end_index]), start_index, end_index, lines - - -def check_model_list_copy(overwrite=False, max_per_line=119): - """Check the model lists in the README and index.rst are consistent and maybe `overwrite`.""" - # Fix potential doc links in the README - with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f: - readme = f.read() - new_readme = readme.replace("https://huggingface.co/diffusers", "https://huggingface.co/docs/diffusers") - new_readme = new_readme.replace( - "https://huggingface.co/docs/main/diffusers", "https://huggingface.co/docs/diffusers/main" - ) - if new_readme != readme: - if overwrite: - with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f: - f.write(new_readme) - else: - raise ValueError( - "The main README contains wrong links to the documentation of Transformers. Run `make fix-copies` to " - "automatically fix them." - ) - - # If the introduction or the conclusion of the list change, the prompts may need to be updated. - index_list, start_index, end_index, lines = _find_text_in_file( - filename=os.path.join(PATH_TO_DOCS, "index.mdx"), - start_prompt="