From 721e017401ded6d11447c4bf964291853231e0d8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 13 Sep 2022 13:24:27 +0200 Subject: [PATCH] [Flax] Make room for more frameworks (#494) * start * finish --- setup.py | 18 +- src/diffusers/__init__.py | 64 +++---- src/diffusers/dependency_versions_table.py | 3 + src/diffusers/utils/dummy_pt_objects.py | 165 ++++++++++++++++++ ...ts.py => dummy_torch_and_scipy_objects.py} | 4 +- ...orch_and_transformers_and_onnx_objects.py} | 4 +- ...> dummy_torch_and_transformers_objects.py} | 16 +- ...rmers_and_inflect_and_unidecode_objects.py | 10 -- 8 files changed, 227 insertions(+), 57 deletions(-) create mode 100644 src/diffusers/utils/dummy_pt_objects.py rename src/diffusers/utils/{dummy_scipy_objects.py => dummy_torch_and_scipy_objects.py} (73%) rename src/diffusers/utils/{dummy_transformers_and_onnx_objects.py => dummy_torch_and_transformers_and_onnx_objects.py} (67%) rename src/diffusers/utils/{dummy_transformers_objects.py => dummy_torch_and_transformers_objects.py} (57%) delete mode 100644 src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py diff --git a/setup.py b/setup.py index fffb8f26..6a929d82 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ To create the package for pypi. """ import re +import os from distutils.core import Command from setuptools import find_packages, setup @@ -82,10 +83,13 @@ _deps = [ "datasets", "filelock", "flake8>=3.8.3", + "flax>=0.4.1", "hf-doc-builder>=0.3.0", "huggingface-hub>=0.8.1", "importlib_metadata", "isort>=5.5.4", + "jax>=0.2.8,!=0.3.2,<=0.3.6", + "jaxlib>=0.1.65,<=0.3.6", "modelcards==0.1.4", "numpy", "pytest", @@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui extras["docs"] = ["hf-doc-builder"] extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"] extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"] -extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] +extras["torch"] = deps_list("torch") + +if os.name == "nt": # windows + extras["flax"] = [] # jax is not supported on windows +else: + extras["flax"] = deps_list("jax", "jaxlib", "flax") + +extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] install_requires = [ deps["importlib_metadata"], @@ -180,13 +191,12 @@ install_requires = [ deps["numpy"], deps["regex"], deps["requests"], - deps["torch"], deps["Pillow"], ] setup( name="diffusers", - version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", @@ -198,7 +208,7 @@ setup( package_dir={"": "src"}, packages=find_packages("src"), include_package_data=True, - python_requires=">=3.6.0", + python_requires=">=3.7.0", install_requires=install_requires, extras_require=extras, entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]}, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 14fb19ef..219f2d8b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -2,6 +2,7 @@ from .utils import ( is_inflect_available, is_onnx_available, is_scipy_available, + is_torch_available, is_transformers_available, is_unidecode_available, ) @@ -10,40 +11,42 @@ from .utils import ( __version__ = "0.4.0.dev0" from .configuration_utils import ConfigMixin -from .modeling_utils import ModelMixin -from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from .onnx_utils import OnnxRuntimeModel -from .optimization import ( - get_constant_schedule, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, - get_scheduler, -) -from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline -from .schedulers import ( - DDIMScheduler, - DDPMScheduler, - KarrasVeScheduler, - PNDMScheduler, - SchedulerMixin, - ScoreSdeVeScheduler, -) from .utils import logging -if is_scipy_available(): +if is_torch_available(): + from .modeling_utils import ModelMixin + from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel + from .optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, + ) + from .pipeline_utils import DiffusionPipeline + from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline + from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + KarrasVeScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, + ) + from .training_utils import EMAModel +else: + from .utils.dummy_pt_objects import * # noqa F403 + +if is_torch_available() and is_scipy_available(): from .schedulers import LMSDiscreteScheduler else: - from .utils.dummy_scipy_objects import * # noqa F403 + from .utils.dummy_torch_and_scipy_objects import * # noqa F403 -from .training_utils import EMAModel - - -if is_transformers_available(): +if is_torch_available() and is_transformers_available(): from .pipelines import ( LDMTextToImagePipeline, StableDiffusionImg2ImgPipeline, @@ -51,10 +54,9 @@ if is_transformers_available(): StableDiffusionPipeline, ) else: - from .utils.dummy_transformers_objects import * # noqa F403 + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 - -if is_transformers_available() and is_onnx_available(): +if is_torch_available() and is_transformers_available() and is_onnx_available(): from .pipelines import StableDiffusionOnnxPipeline else: - from .utils.dummy_transformers_and_onnx_objects import * # noqa F403 + from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index fa8ddfe0..dffb5abc 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -8,10 +8,13 @@ deps = { "datasets": "datasets", "filelock": "filelock", "flake8": "flake8>=3.8.3", + "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", "huggingface-hub": "huggingface-hub>=0.8.1", "importlib_metadata": "importlib_metadata", "isort": "isort>=5.5.4", + "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", + "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards==0.1.4", "numpy": "numpy", "pytest": "pytest", diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py new file mode 100644 index 00000000..531c0b77 --- /dev/null +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -0,0 +1,165 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class ModelMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AutoencoderKL(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UNet2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class VQModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +def get_constant_schedule(*args, **kwargs): + requires_backends(get_constant_schedule, ["torch"]) + + +def get_constant_schedule_with_warmup(*args, **kwargs): + requires_backends(get_constant_schedule_with_warmup, ["torch"]) + + +def get_cosine_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_schedule_with_warmup, ["torch"]) + + +def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) + + +def get_linear_schedule_with_warmup(*args, **kwargs): + requires_backends(get_linear_schedule_with_warmup, ["torch"]) + + +def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): + requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) + + +def get_scheduler(*args, **kwargs): + requires_backends(get_scheduler, ["torch"]) + + +class DiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDIMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDPMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class KarrasVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PNDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ScoreSdeVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class KarrasVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SchedulerMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EMAModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) diff --git a/src/diffusers/utils/dummy_scipy_objects.py b/src/diffusers/utils/dummy_torch_and_scipy_objects.py similarity index 73% rename from src/diffusers/utils/dummy_scipy_objects.py rename to src/diffusers/utils/dummy_torch_and_scipy_objects.py index 3706c575..49c89564 100644 --- a/src/diffusers/utils/dummy_scipy_objects.py +++ b/src/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends class LMSDiscreteScheduler(metaclass=DummyObject): - _backends = ["scipy"] + _backends = ["torch", "scipy"] def __init__(self, *args, **kwargs): - requires_backends(self, ["scipy"]) + requires_backends(self, ["torch", "scipy"]) diff --git a/src/diffusers/utils/dummy_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py similarity index 67% rename from src/diffusers/utils/dummy_transformers_and_onnx_objects.py rename to src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index 2e34b5ce..967e231d 100644 --- a/src/diffusers/utils/dummy_transformers_and_onnx_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends class StableDiffusionOnnxPipeline(metaclass=DummyObject): - _backends = ["transformers", "onnx"] + _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers", "onnx"]) + requires_backends(self, ["torch", "transformers", "onnx"]) diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py similarity index 57% rename from src/diffusers/utils/dummy_transformers_objects.py rename to src/diffusers/utils/dummy_torch_and_transformers_objects.py index e05eb814..6e4ab48c 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends class LDMTextToImagePipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) class StableDiffusionInpaintPipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) class StableDiffusionPipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) diff --git a/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py deleted file mode 100644 index 8c2aec21..00000000 --- a/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py +++ /dev/null @@ -1,10 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa -from ..utils import DummyObject, requires_backends - - -class GradTTSPipeline(metaclass=DummyObject): - _backends = ["transformers", "inflect", "unidecode"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers", "inflect", "unidecode"])