save intermediate
This commit is contained in:
parent
1997b90838
commit
b4e6a7403d
2
Makefile
2
Makefile
|
@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
|
|||
# Make marked copies of snippets of codes conform to the original
|
||||
|
||||
fix-copies:
|
||||
python utils/check_copies.py --fix_and_overwrite
|
||||
python utils/check_table.py --fix_and_overwrite
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
python utils/check_copies.py --fix_and_overwrite
|
||||
|
||||
# Run tests for the library
|
||||
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
from .utils import is_transformers_available
|
||||
|
||||
__version__ = "0.0.4"
|
||||
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
|
||||
from .models.unet_grad_tts import UNetGradTTSModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
|
||||
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
|
||||
|
||||
if is_transformers_available():
|
||||
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
|
||||
else:
|
||||
from .utils.dummy_transformers_objects import *
|
||||
|
|
|
@ -2,15 +2,6 @@ from .pipeline_bddm import BDDM
|
|||
from .pipeline_ddim import DDIM
|
||||
from .pipeline_ddpm import DDPM
|
||||
from .pipeline_grad_tts import GradTTS
|
||||
|
||||
|
||||
try:
|
||||
from .pipeline_glide import GLIDE
|
||||
except (NameError, ImportError):
|
||||
|
||||
class GLIDE:
|
||||
pass
|
||||
|
||||
|
||||
from .pipeline_latent_diffusion import LatentDiffusion
|
||||
from .pipeline_pndm import PNDM
|
||||
|
|
|
@ -1,12 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
import os
|
||||
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -21,6 +12,10 @@ import os
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from requests.exceptions import HTTPError
|
||||
import importlib
|
||||
import importlib_metadata
|
||||
import os
|
||||
from .logging import logger
|
||||
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
|
@ -36,6 +31,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
|||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
|
||||
def is_transformers_available():
|
||||
return _transformers_available
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
# flake8: noqa
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class GLIDESuperResUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GLIDETextToImageUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GLIDEUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
|
@ -20,10 +20,10 @@ import re
|
|||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_dummies.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
PATH_TO_DIFFUSERS = "src/diffusers"
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")
|
||||
# Matches from xxx import bla
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
|
||||
|
@ -50,36 +50,30 @@ def {0}(*args, **kwargs):
|
|||
|
||||
def find_backend(line):
|
||||
"""Find one (or multiple) backend in a code line of the init."""
|
||||
if _re_test_backend.search(line) is None:
|
||||
backends = _re_backend.findall(line)
|
||||
if len(backends) == 0:
|
||||
return None
|
||||
backends = [b[0] for b in _re_backend.findall(line)]
|
||||
backends.sort()
|
||||
return "_and_".join(backends)
|
||||
|
||||
return backends[0]
|
||||
|
||||
|
||||
def read_init():
|
||||
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Get to the point we do the actual imports for type checking
|
||||
line_index = 0
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
line_index += 1
|
||||
|
||||
backend_specific_objects = {}
|
||||
# Go through the end of the file
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
while not lines[line_index].startswith(" else:"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
line_index += 1
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
while not lines[line_index].startswith("else:"):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_single_line_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
|
@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
|
|||
short_names = {"torch": "pt"}
|
||||
|
||||
# Locate actual dummy modules and read their content.
|
||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||
path = os.path.join(PATH_TO_DIFFUSERS, "utils")
|
||||
dummy_file_paths = {
|
||||
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
|
||||
for backend in dummy_files.keys()
|
||||
|
@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
|
|||
if dummy_files[backend] != actual_dummies[backend]:
|
||||
if overwrite:
|
||||
print(
|
||||
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
"__init__ has new objects."
|
||||
)
|
||||
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
|
||||
|
@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
|
|||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in "
|
||||
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||
f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||
"to fix this."
|
||||
)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import re
|
|||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_table.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
TRANSFORMERS_PATH = "src/diffusers"
|
||||
PATH_TO_DOCS = "docs/source/en"
|
||||
REPO_PATH = "."
|
||||
|
||||
|
@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
|
|||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
# This is to make sure the diffusers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
"diffusers",
|
||||
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
|
||||
submodule_search_locations=[TRANSFORMERS_PATH],
|
||||
)
|
||||
transformers_module = spec.loader.load_module()
|
||||
diffusers_module = spec.loader.load_module()
|
||||
|
||||
|
||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||
|
@ -88,10 +88,10 @@ def _center_text(text, width):
|
|||
def get_model_table_from_auto_modules():
|
||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||
# Dictionary model names to config.
|
||||
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
model_name_to_config = {
|
||||
name: config_maping_names[code]
|
||||
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
||||
for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
|
||||
if code in config_maping_names
|
||||
}
|
||||
model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
|
||||
|
@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
|
|||
tf_models = collections.defaultdict(bool)
|
||||
flax_models = collections.defaultdict(bool)
|
||||
|
||||
# Let's lookup through all transformers object (once).
|
||||
for attr_name in dir(transformers_module):
|
||||
# Let's lookup through all diffusers object (once).
|
||||
for attr_name in dir(diffusers_module):
|
||||
lookup_dict = None
|
||||
if attr_name.endswith("Tokenizer"):
|
||||
lookup_dict = slow_tokenizers
|
||||
|
|
Loading…
Reference in New Issue