Introduce the copy mechanism (#924)
* Introduce the copy mechanism * init tests * fix dummy tests * with * update copies tests
This commit is contained in:
parent
cc36f2e7ff
commit
32bf4fdc43
|
@ -31,3 +31,20 @@ jobs:
|
|||
isort --check-only examples tests src utils scripts
|
||||
flake8 examples tests src utils scripts
|
||||
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
||||
|
||||
check_repository_consistency:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
|
|
1
Makefile
1
Makefile
|
@ -67,6 +67,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
|
|||
# Make marked copies of snippets of codes conform to the original
|
||||
|
||||
fix-copies:
|
||||
python utils/check_copies.py --fix_and_overwrite
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
|
||||
# Run tests for the library
|
||||
|
|
|
@ -28,6 +28,7 @@ from .scheduling_utils import SchedulerMixin
|
|||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
||||
class DDIMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
|
|
@ -26,6 +26,7 @@ from .scheduling_utils import SchedulerMixin
|
|||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
|
||||
class LMSDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import black
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
|
||||
import check_copies # noqa: E402
|
||||
|
||||
|
||||
# This is the reference code that will be used in the tests.
|
||||
# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated.
|
||||
REFERENCE_CODE = """ \"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
\"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
"""
|
||||
|
||||
|
||||
class CopyCheckTester(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.diffusers_dir = tempfile.mkdtemp()
|
||||
os.makedirs(os.path.join(self.diffusers_dir, "schedulers/"))
|
||||
check_copies.DIFFUSERS_PATH = self.diffusers_dir
|
||||
shutil.copy(
|
||||
os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"),
|
||||
os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"),
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
check_copies.DIFFUSERS_PATH = "src/diffusers"
|
||||
shutil.rmtree(self.diffusers_dir)
|
||||
|
||||
def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
|
||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||
if overwrite_result is not None:
|
||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||
code = black.format_str(code, mode=mode)
|
||||
fname = os.path.join(self.diffusers_dir, "new_code.py")
|
||||
with open(fname, "w", newline="\n") as f:
|
||||
f.write(code)
|
||||
if overwrite_result is None:
|
||||
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
|
||||
else:
|
||||
check_copies.is_copy_consistent(f.name, overwrite=True)
|
||||
with open(fname, "r") as f:
|
||||
self.assertTrue(f.read(), expected)
|
||||
|
||||
def test_find_code_in_diffusers(self):
|
||||
code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput")
|
||||
self.assertEqual(code, REFERENCE_CODE)
|
||||
|
||||
def test_is_copy_consistent(self):
|
||||
# Base copy consistency
|
||||
self.check_copy_consistency(
|
||||
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
|
||||
"DDPMSchedulerOutput",
|
||||
REFERENCE_CODE + "\n",
|
||||
)
|
||||
|
||||
# With no empty line at the end
|
||||
self.check_copy_consistency(
|
||||
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
|
||||
"DDPMSchedulerOutput",
|
||||
REFERENCE_CODE,
|
||||
)
|
||||
|
||||
# Copy consistency with rename
|
||||
self.check_copy_consistency(
|
||||
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
|
||||
"TestSchedulerOutput",
|
||||
re.sub("DDPM", "Test", REFERENCE_CODE),
|
||||
)
|
||||
|
||||
# Copy consistency with a really long name
|
||||
long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
|
||||
self.check_copy_consistency(
|
||||
f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}",
|
||||
f"{long_class_name}SchedulerOutput",
|
||||
re.sub("Bert", long_class_name, REFERENCE_CODE),
|
||||
)
|
||||
|
||||
# Copy consistency with overwrite
|
||||
self.check_copy_consistency(
|
||||
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
|
||||
"TestSchedulerOutput",
|
||||
REFERENCE_CODE,
|
||||
overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE),
|
||||
)
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
|
||||
import check_dummies
|
||||
from check_dummies import create_dummy_files, create_dummy_object, find_backend, read_init # noqa: E402
|
||||
|
||||
|
||||
# Align TRANSFORMERS_PATH in check_dummies with the current path
|
||||
check_dummies.PATH_TO_DIFFUSERS = os.path.join(git_repo_path, "src", "diffusers")
|
||||
|
||||
|
||||
class CheckDummiesTester(unittest.TestCase):
|
||||
def test_find_backend(self):
|
||||
simple_backend = find_backend(" if not is_torch_available():")
|
||||
self.assertEqual(simple_backend, "torch")
|
||||
|
||||
# backend_with_underscore = find_backend(" if not is_tensorflow_text_available():")
|
||||
# self.assertEqual(backend_with_underscore, "tensorflow_text")
|
||||
|
||||
double_backend = find_backend(" if not (is_torch_available() and is_transformers_available()):")
|
||||
self.assertEqual(double_backend, "torch_and_transformers")
|
||||
|
||||
# double_backend_with_underscore = find_backend(
|
||||
# " if not (is_sentencepiece_available() and is_tensorflow_text_available()):"
|
||||
# )
|
||||
# self.assertEqual(double_backend_with_underscore, "sentencepiece_and_tensorflow_text")
|
||||
|
||||
triple_backend = find_backend(
|
||||
" if not (is_torch_available() and is_transformers_available() and is_onnx_available()):"
|
||||
)
|
||||
self.assertEqual(triple_backend, "torch_and_transformers_and_onnx")
|
||||
|
||||
def test_read_init(self):
|
||||
objects = read_init()
|
||||
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
||||
self.assertIn("torch", objects)
|
||||
self.assertIn("torch_and_transformers", objects)
|
||||
self.assertIn("flax_and_transformers", objects)
|
||||
self.assertIn("torch_and_transformers_and_onnx", objects)
|
||||
|
||||
# Likewise, we can't assert on the exact content of a key
|
||||
self.assertIn("UNet2DModel", objects["torch"])
|
||||
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
||||
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
||||
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
||||
self.assertIn("LMSDiscreteScheduler", objects["torch_and_scipy"])
|
||||
self.assertIn("OnnxStableDiffusionPipeline", objects["torch_and_transformers_and_onnx"])
|
||||
|
||||
def test_create_dummy_object(self):
|
||||
dummy_constant = create_dummy_object("CONSTANT", "'torch'")
|
||||
self.assertEqual(dummy_constant, "\nCONSTANT = None\n")
|
||||
|
||||
dummy_function = create_dummy_object("function", "'torch'")
|
||||
self.assertEqual(
|
||||
dummy_function, "\ndef function(*args, **kwargs):\n requires_backends(function, 'torch')\n"
|
||||
)
|
||||
|
||||
expected_dummy_class = """
|
||||
class FakeClass(metaclass=DummyObject):
|
||||
_backends = 'torch'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, 'torch')
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, 'torch')
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, 'torch')
|
||||
"""
|
||||
dummy_class = create_dummy_object("FakeClass", "'torch'")
|
||||
self.assertEqual(dummy_class, expected_dummy_class)
|
||||
|
||||
def test_create_dummy_files(self):
|
||||
expected_dummy_pytorch_file = """# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
# flake8: noqa
|
||||
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
CONSTANT = None
|
||||
|
||||
|
||||
def function(*args, **kwargs):
|
||||
requires_backends(function, ["torch"])
|
||||
|
||||
|
||||
class FakeClass(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
"""
|
||||
dummy_files = create_dummy_files({"torch": ["CONSTANT", "function", "FakeClass"]})
|
||||
self.assertEqual(dummy_files["torch"], expected_dummy_pytorch_file)
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,6 +15,7 @@
|
|||
|
||||
import argparse
|
||||
import glob
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
|
||||
|
@ -24,52 +25,17 @@ from doc_builder.style_doc import style_docstrings_in_code
|
|||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
TRANSFORMERS_PATH = "src/diffusers"
|
||||
PATH_TO_DOCS = "docs/source/en"
|
||||
DIFFUSERS_PATH = "src/diffusers"
|
||||
REPO_PATH = "."
|
||||
|
||||
# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
|
||||
FULL_COPIES = {
|
||||
"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
|
||||
"examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
|
||||
}
|
||||
|
||||
|
||||
LOCALIZED_READMES = {
|
||||
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
||||
"README.md": {
|
||||
"start_prompt": "🤗 Transformers currently provides the following architectures",
|
||||
"end_prompt": "1. Want to contribute a new model?",
|
||||
"format_model_list": (
|
||||
"**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
|
||||
" {paper_authors}.{supplements}"
|
||||
),
|
||||
},
|
||||
"README_zh-hans.md": {
|
||||
"start_prompt": "🤗 Transformers 目前支持如下的架构",
|
||||
"end_prompt": "1. 想要贡献新的模型?",
|
||||
"format_model_list": (
|
||||
"**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link} 由 {paper_authors}"
|
||||
" 发布。{supplements}"
|
||||
),
|
||||
},
|
||||
"README_zh-hant.md": {
|
||||
"start_prompt": "🤗 Transformers 目前支援以下的架構",
|
||||
"end_prompt": "1. 想要貢獻新的模型?",
|
||||
"format_model_list": (
|
||||
"**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
|
||||
" {paper_authors}.{supplements}"
|
||||
),
|
||||
},
|
||||
"README_ko.md": {
|
||||
"start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다",
|
||||
"end_prompt": "1. 새로운 모델을 올리고 싶나요?",
|
||||
"format_model_list": (
|
||||
"**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
|
||||
" {paper_authors}.{supplements}"
|
||||
),
|
||||
},
|
||||
}
|
||||
# This is to make sure the diffusers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"diffusers",
|
||||
os.path.join(DIFFUSERS_PATH, "__init__.py"),
|
||||
submodule_search_locations=[DIFFUSERS_PATH],
|
||||
)
|
||||
diffusers_module = spec.loader.load_module()
|
||||
|
||||
|
||||
def _should_continue(line, indent):
|
||||
|
@ -83,14 +49,14 @@ def find_code_in_diffusers(object_name):
|
|||
|
||||
# First let's find the module where our object lives.
|
||||
module = parts[i]
|
||||
while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
|
||||
while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")):
|
||||
i += 1
|
||||
if i < len(parts):
|
||||
module = os.path.join(module, parts[i])
|
||||
if i >= len(parts):
|
||||
raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")
|
||||
|
||||
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
with open(os.path.join(DIFFUSERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Now let's find the class / func in the code!
|
||||
|
@ -121,6 +87,7 @@ def find_code_in_diffusers(object_name):
|
|||
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
||||
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
|
||||
|
||||
|
||||
def get_indent(code):
|
||||
|
@ -140,7 +107,7 @@ def blackify(code):
|
|||
has_indent = len(get_indent(code)) > 0
|
||||
if has_indent:
|
||||
code = f"class Bla:\n{code}"
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True)
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True)
|
||||
result = black.format_str(code, mode=mode)
|
||||
result, _ = style_docstrings_in_code(result)
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
@ -149,7 +116,6 @@ def blackify(code):
|
|||
def is_copy_consistent(filename, overwrite=False):
|
||||
"""
|
||||
Check if the code commented as a copy in `filename` matches the original.
|
||||
|
||||
Return the differences or overwrites the content depending on `overwrite`.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
|
@ -221,7 +187,7 @@ def is_copy_consistent(filename, overwrite=False):
|
|||
|
||||
|
||||
def check_copies(overwrite: bool = False):
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
all_files = glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True)
|
||||
diffs = []
|
||||
for filename in all_files:
|
||||
new_diffs = is_copy_consistent(filename, overwrite)
|
||||
|
@ -235,224 +201,9 @@ def check_copies(overwrite: bool = False):
|
|||
)
|
||||
|
||||
|
||||
# check_model_list_copy(overwrite=overwrite)
|
||||
|
||||
|
||||
def check_full_copies(overwrite: bool = False):
|
||||
diffs = []
|
||||
for target, source in FULL_COPIES.items():
|
||||
with open(source, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
with open(target, "r", encoding="utf-8") as f:
|
||||
target_code = f.read()
|
||||
if source_code != target_code:
|
||||
if overwrite:
|
||||
with open(target, "w", encoding="utf-8") as f:
|
||||
print(f"Replacing the content of {target} by the one of {source}.")
|
||||
f.write(source_code)
|
||||
else:
|
||||
diffs.append(f"- {target}: copy does not match {source}.")
|
||||
|
||||
if not overwrite and len(diffs) > 0:
|
||||
diff = "\n".join(diffs)
|
||||
raise Exception(
|
||||
"Found the following copy inconsistencies:\n"
|
||||
+ diff
|
||||
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
|
||||
)
|
||||
|
||||
|
||||
def get_model_list(filename, start_prompt, end_prompt):
|
||||
"""Extracts the model list from the README."""
|
||||
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
# Find the start of the list.
|
||||
start_index = 0
|
||||
while not lines[start_index].startswith(start_prompt):
|
||||
start_index += 1
|
||||
start_index += 1
|
||||
|
||||
result = []
|
||||
current_line = ""
|
||||
end_index = start_index
|
||||
|
||||
while not lines[end_index].startswith(end_prompt):
|
||||
if lines[end_index].startswith("1."):
|
||||
if len(current_line) > 1:
|
||||
result.append(current_line)
|
||||
current_line = lines[end_index]
|
||||
elif len(lines[end_index]) > 1:
|
||||
current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}"
|
||||
end_index += 1
|
||||
if len(current_line) > 1:
|
||||
result.append(current_line)
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def convert_to_localized_md(model_list, localized_model_list, format_str):
|
||||
"""Convert `model_list` to each localized README."""
|
||||
|
||||
def _rep(match):
|
||||
title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
|
||||
return format_str.format(
|
||||
title=title,
|
||||
model_link=model_link,
|
||||
paper_affiliations=paper_affiliations,
|
||||
paper_title_link=paper_title_link,
|
||||
paper_authors=paper_authors,
|
||||
supplements=" " + supplements.strip() if len(supplements) != 0 else "",
|
||||
)
|
||||
|
||||
# This regex captures metadata from an English model description, including model title, model link,
|
||||
# affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
|
||||
_re_capture_meta = re.compile(
|
||||
r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
|
||||
)
|
||||
# This regex is used to synchronize link.
|
||||
_re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")
|
||||
|
||||
if len(localized_model_list) == 0:
|
||||
localized_model_index = {}
|
||||
else:
|
||||
try:
|
||||
localized_model_index = {
|
||||
re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line
|
||||
for line in localized_model_list.strip().split("\n")
|
||||
}
|
||||
except AttributeError:
|
||||
raise AttributeError("A model name in localized READMEs cannot be recognized.")
|
||||
|
||||
model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]
|
||||
|
||||
# We exclude keys in localized README not in the main one.
|
||||
readmes_match = not any([k not in model_keys for k in localized_model_index])
|
||||
localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}
|
||||
|
||||
for model in model_list.strip().split("\n"):
|
||||
title, model_link = _re_capture_title_link.search(model).groups()
|
||||
if title not in localized_model_index:
|
||||
readmes_match = False
|
||||
# Add an anchor white space behind a model description string for regex.
|
||||
# If metadata cannot be captured, the English version will be directly copied.
|
||||
localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
|
||||
else:
|
||||
# Synchronize link
|
||||
localized_model_index[title] = _re_capture_title_link.sub(
|
||||
f"**[{title}]({model_link})**", localized_model_index[title], count=1
|
||||
)
|
||||
|
||||
sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())
|
||||
|
||||
return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
|
||||
|
||||
|
||||
def convert_readme_to_index(model_list):
|
||||
model_list = model_list.replace("https://huggingface.co/docs/diffusers/main/", "")
|
||||
return model_list.replace("https://huggingface.co/docs/diffusers/", "")
|
||||
|
||||
|
||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
"""
|
||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
||||
lines.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
# Find the start prompt.
|
||||
start_index = 0
|
||||
while not lines[start_index].startswith(start_prompt):
|
||||
start_index += 1
|
||||
start_index += 1
|
||||
|
||||
end_index = start_index
|
||||
while not lines[end_index].startswith(end_prompt):
|
||||
end_index += 1
|
||||
end_index -= 1
|
||||
|
||||
while len(lines[start_index]) <= 1:
|
||||
start_index += 1
|
||||
while len(lines[end_index]) <= 1:
|
||||
end_index -= 1
|
||||
end_index += 1
|
||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
|
||||
# Fix potential doc links in the README
|
||||
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
|
||||
readme = f.read()
|
||||
new_readme = readme.replace("https://huggingface.co/diffusers", "https://huggingface.co/docs/diffusers")
|
||||
new_readme = new_readme.replace(
|
||||
"https://huggingface.co/docs/main/diffusers", "https://huggingface.co/docs/diffusers/main"
|
||||
)
|
||||
if new_readme != readme:
|
||||
if overwrite:
|
||||
with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(new_readme)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main README contains wrong links to the documentation of Transformers. Run `make fix-copies` to "
|
||||
"automatically fix them."
|
||||
)
|
||||
|
||||
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
||||
index_list, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
|
||||
start_prompt="<!--This list is updated automatically from the README",
|
||||
end_prompt="### Supported frameworks",
|
||||
)
|
||||
md_list = get_model_list(
|
||||
filename="README.md",
|
||||
start_prompt=LOCALIZED_READMES["README.md"]["start_prompt"],
|
||||
end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"],
|
||||
)
|
||||
|
||||
converted_md_lists = []
|
||||
for filename, value in LOCALIZED_READMES.items():
|
||||
_start_prompt = value["start_prompt"]
|
||||
_end_prompt = value["end_prompt"]
|
||||
_format_model_list = value["format_model_list"]
|
||||
|
||||
localized_md_list = get_model_list(filename, _start_prompt, _end_prompt)
|
||||
readmes_match, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)
|
||||
|
||||
converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))
|
||||
|
||||
converted_md_list = convert_readme_to_index(md_list)
|
||||
if converted_md_list != index_list:
|
||||
if overwrite:
|
||||
with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.writelines(lines[:start_index] + [converted_md_list] + lines[end_index:])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model list in the README changed and the list in `index.mdx` has not been updated. Run "
|
||||
"`make fix-copies` to fix this."
|
||||
)
|
||||
|
||||
for converted_md_list in converted_md_lists:
|
||||
filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list
|
||||
|
||||
if filename == "README.md":
|
||||
continue
|
||||
if overwrite:
|
||||
_, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(REPO_PATH, filename), start_prompt=_start_prompt, end_prompt=_end_prompt
|
||||
)
|
||||
with open(os.path.join(REPO_PATH, filename), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.writelines(lines[:start_index] + [converted_md] + lines[end_index:])
|
||||
elif not readmes_match:
|
||||
raise ValueError(
|
||||
f"The model list in the README changed and the list in `{filename}` has not been updated. Run "
|
||||
"`make fix-copies` to fix this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_copies(args.fix_and_overwrite)
|
||||
check_full_copies(args.fix_and_overwrite)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -106,9 +106,10 @@ def create_dummy_object(name, backend_name):
|
|||
return DUMMY_CLASS.format(name, backend_name)
|
||||
|
||||
|
||||
def create_dummy_files():
|
||||
def create_dummy_files(backend_specific_objects=None):
|
||||
"""Create the content of the dummy files."""
|
||||
backend_specific_objects = read_init()
|
||||
if backend_specific_objects is None:
|
||||
backend_specific_objects = read_init()
|
||||
# For special correspondence backend to module name as used in the function requires_modulename
|
||||
dummy_files = {}
|
||||
|
||||
|
|
Loading…
Reference in New Issue