diff --git a/Makefile b/Makefile index ddf143b6..ba718723 100644 --- a/Makefile +++ b/Makefile @@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency # Make marked copies of snippets of codes conform to the original fix-copies: - python utils/check_copies.py --fix_and_overwrite python utils/check_table.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite + python utils/check_copies.py --fix_and_overwrite # Run tests for the library diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dc69d8bf..a8b5acf1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,15 +1,20 @@ # flake8: noqa # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. +from .utils import is_transformers_available __version__ = "0.0.4" from .modeling_utils import ModelMixin from .models.unet import UNetModel -from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler + +if is_transformers_available(): + from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel +else: + from .utils.dummy_transformers_objects import * diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 91af443c..c0f737f9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -2,15 +2,6 @@ from .pipeline_bddm import BDDM from .pipeline_ddim import DDIM from .pipeline_ddpm import DDPM from .pipeline_grad_tts import GradTTS - - -try: - from .pipeline_glide import GLIDE -except (NameError, ImportError): - - class GLIDE: - pass - - +from .pipeline_glide import GLIDE from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_pndm import PNDM diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7f25da44..407066fc 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -1,12 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 - -# flake8: noqa -# There's no way to ignore "F401 '...' imported but unused" warnings in this -# module, but to preserve other warnings. So, don't check this module at all. - -import os - # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +12,10 @@ import os # See the License for the specific language governing permissions and # limitations under the License. from requests.exceptions import HTTPError +import importlib +import importlib_metadata +import os +from .logging import logger hf_cache_home = os.path.expanduser( @@ -36,6 +31,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +def is_transformers_available(): + return _transformers_available + + class RepositoryNotFoundError(HTTPError): """ Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py new file mode 100644 index 00000000..83578889 --- /dev/null +++ b/src/diffusers/utils/dummy_transformers_objects.py @@ -0,0 +1,24 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa +from ..utils import DummyObject, requires_backends + + +class GLIDESuperResUNetModel(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class GLIDETextToImageUNetModel(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) + + +class GLIDEUNetModel(metaclass=DummyObject): + _backends = ["transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["transformers"]) diff --git a/utils/check_dummies.py b/utils/check_dummies.py index d6c1c4b5..e132b349 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -20,10 +20,10 @@ import re # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_dummies.py -PATH_TO_TRANSFORMERS = "src/transformers" +PATH_TO_DIFFUSERS = "src/diffusers" # Matches is_xxx_available() -_re_backend = re.compile(r"is\_([a-z_]*)_available()") +_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)") # Matches from xxx import bla _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") _re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)") @@ -50,36 +50,30 @@ def {0}(*args, **kwargs): def find_backend(line): """Find one (or multiple) backend in a code line of the init.""" - if _re_test_backend.search(line) is None: + backends = _re_backend.findall(line) + if len(backends) == 0: return None - backends = [b[0] for b in _re_backend.findall(line)] - backends.sort() - return "_and_".join(backends) + + return backends[0] def read_init(): """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.""" - with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: + with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Get to the point we do the actual imports for type checking line_index = 0 - while not lines[line_index].startswith("if TYPE_CHECKING"): - line_index += 1 - backend_specific_objects = {} # Go through the end of the file while line_index < len(lines): # If the line is an if is_backend_available, we grab all objects associated. backend = find_backend(lines[line_index]) if backend is not None: - while not lines[line_index].startswith(" else:"): - line_index += 1 - line_index += 1 - objects = [] + line_index += 1 # Until we unindent, add backend objects to the list - while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): + while not lines[line_index].startswith("else:"): line = lines[line_index] single_line_import_search = _re_single_line_import.search(line) if single_line_import_search is not None: @@ -129,7 +123,7 @@ def check_dummies(overwrite=False): short_names = {"torch": "pt"} # Locate actual dummy modules and read their content. - path = os.path.join(PATH_TO_TRANSFORMERS, "utils") + path = os.path.join(PATH_TO_DIFFUSERS, "utils") dummy_file_paths = { backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") for backend in dummy_files.keys() @@ -147,7 +141,7 @@ def check_dummies(overwrite=False): if dummy_files[backend] != actual_dummies[backend]: if overwrite: print( - f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " + f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " "__init__ has new objects." ) with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f: @@ -155,7 +149,7 @@ def check_dummies(overwrite=False): else: raise ValueError( "The main __init__ has objects that are not present in " - f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " + f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " "to fix this." ) diff --git a/utils/check_table.py b/utils/check_table.py index 3a900551..6c74308c 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -22,7 +22,7 @@ import re # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_table.py -TRANSFORMERS_PATH = "src/transformers" +TRANSFORMERS_PATH = "src/diffusers" PATH_TO_DOCS = "docs/source/en" REPO_PATH = "." @@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") -# This is to make sure the transformers module imported is the one in the repo. +# This is to make sure the diffusers module imported is the one in the repo. spec = importlib.util.spec_from_file_location( - "transformers", + "diffusers", os.path.join(TRANSFORMERS_PATH, "__init__.py"), submodule_search_locations=[TRANSFORMERS_PATH], ) -transformers_module = spec.loader.load_module() +diffusers_module = spec.loader.load_module() # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python @@ -88,10 +88,10 @@ def _center_text(text, width): def get_model_table_from_auto_modules(): """Generates an up-to-date model table from the content of the auto modules.""" # Dictionary model names to config. - config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES + config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES model_name_to_config = { name: config_maping_names[code] - for code, name in transformers_module.MODEL_NAMES_MAPPING.items() + for code, name in diffusers_module.MODEL_NAMES_MAPPING.items() if code in config_maping_names } model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()} @@ -103,8 +103,8 @@ def get_model_table_from_auto_modules(): tf_models = collections.defaultdict(bool) flax_models = collections.defaultdict(bool) - # Let's lookup through all transformers object (once). - for attr_name in dir(transformers_module): + # Let's lookup through all diffusers object (once). + for attr_name in dir(diffusers_module): lookup_dict = None if attr_name.endswith("Tokenizer"): lookup_dict = slow_tokenizers