save intermediate

This commit is contained in:
Patrick von Platen 2022-06-17 16:58:45 +02:00
parent 1997b90838
commit b4e6a7403d
7 changed files with 68 additions and 47 deletions

View File

@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
# Make marked copies of snippets of codes conform to the original
fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
python utils/check_copies.py --fix_and_overwrite
# Run tests for the library

View File

@ -1,15 +1,20 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_transformers_available
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
if is_transformers_available():
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
else:
from .utils.dummy_transformers_objects import *

View File

@ -2,15 +2,6 @@ from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_grad_tts import GradTTS
try:
from .pipeline_glide import GLIDE
except (NameError, ImportError):
class GLIDE:
pass
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_pndm import PNDM

View File

@ -1,12 +1,3 @@
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
import os
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +12,10 @@ import os
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError
import importlib
import importlib_metadata
import os
from .logging import logger
hf_cache_home = os.path.expanduser(
@ -36,6 +31,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
_transformers_version = importlib_metadata.version("transformers")
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
_transformers_available = False
def is_transformers_available():
return _transformers_available
class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does

View File

@ -0,0 +1,24 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class GLIDESuperResUNetModel(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
class GLIDETextToImageUNetModel(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
class GLIDEUNetModel(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])

View File

@ -20,10 +20,10 @@ import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_DIFFUSERS = "src/diffusers"
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
@ -50,36 +50,30 @@ def {0}(*args, **kwargs):
def find_backend(line):
"""Find one (or multiple) backend in a code line of the init."""
if _re_test_backend.search(line) is None:
backends = _re_backend.findall(line)
if len(backends) == 0:
return None
backends = [b[0] for b in _re_backend.findall(line)]
backends.sort()
return "_and_".join(backends)
return backends[0]
def read_init():
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Get to the point we do the actual imports for type checking
line_index = 0
while not lines[line_index].startswith("if TYPE_CHECKING"):
line_index += 1
backend_specific_objects = {}
# Go through the end of the file
while line_index < len(lines):
# If the line is an if is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
if backend is not None:
while not lines[line_index].startswith(" else:"):
line_index += 1
line_index += 1
objects = []
line_index += 1
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
while not lines[line_index].startswith("else:"):
line = lines[line_index]
single_line_import_search = _re_single_line_import.search(line)
if single_line_import_search is not None:
@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
short_names = {"torch": "pt"}
# Locate actual dummy modules and read their content.
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
path = os.path.join(PATH_TO_DIFFUSERS, "utils")
dummy_file_paths = {
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
for backend in dummy_files.keys()
@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
if dummy_files[backend] != actual_dummies[backend]:
if overwrite:
print(
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
"__init__ has new objects."
)
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
else:
raise ValueError(
"The main __init__ has objects that are not present in "
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
"to fix this."
)

View File

@ -22,7 +22,7 @@ import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_table.py
TRANSFORMERS_PATH = "src/transformers"
TRANSFORMERS_PATH = "src/diffusers"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# This is to make sure the transformers module imported is the one in the repo.
# This is to make sure the diffusers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
"diffusers",
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
submodule_search_locations=[TRANSFORMERS_PATH],
)
transformers_module = spec.loader.load_module()
diffusers_module = spec.loader.load_module()
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
@ -88,10 +88,10 @@ def _center_text(text, width):
def get_model_table_from_auto_modules():
"""Generates an up-to-date model table from the content of the auto modules."""
# Dictionary model names to config.
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
model_name_to_config = {
name: config_maping_names[code]
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
if code in config_maping_names
}
model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
tf_models = collections.defaultdict(bool)
flax_models = collections.defaultdict(bool)
# Let's lookup through all transformers object (once).
for attr_name in dir(transformers_module):
# Let's lookup through all diffusers object (once).
for attr_name in dir(diffusers_module):
lookup_dict = None
if attr_name.endswith("Tokenizer"):
lookup_dict = slow_tokenizers