From 970e30606c2944e3286f56e8eb6d3dc6d1eb85f7 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 6 Oct 2022 18:35:40 +0200
Subject: [PATCH] Revert "[v0.4.0] Temporarily remove Flax modules from the public API (#755)"

This reverts commit 2e209c30cf6f2ba42001d0629dc6b7ce354b9a9d.
---
 docs/source/api/models.mdx                  | 18 +++++++++++++
 docs/source/api/schedulers.mdx              |  2 +-
 setup.py                                    | 12 +++++++--
 src/diffusers/__init__.py                   | 23 ++++++++++++++++
 src/diffusers/dependency_versions_table.py  |  3 +++
 src/diffusers/models/__init__.py            |  6 ++++-
 src/diffusers/pipelines/__init__.py         |  3 +++
 .../pipelines/stable_diffusion/__init__.py  | 26 ++++++++++++++++++-
 src/diffusers/schedulers/__init__.py        | 13 +++++++++-
 9 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx
index 525548e7..c92fdccb 100644
--- a/docs/source/api/models.mdx
+++ b/docs/source/api/models.mdx
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 
 ## AutoencoderKL
 [[autodoc]] AutoencoderKL
+
+## FlaxModelMixin
+[[autodoc]] FlaxModelMixin
+
+## FlaxUNet2DConditionOutput
+[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
+
+## FlaxUNet2DConditionModel
+[[autodoc]] FlaxUNet2DConditionModel
+
+## FlaxDecoderOutput
+[[autodoc]] models.vae_flax.FlaxDecoderOutput
+
+## FlaxAutoencoderKLOutput
+[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
+
+## FlaxAutoencoderKL
+[[autodoc]] FlaxAutoencoderKL
diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx
index d2c4c4d4..12a6b5c5 100644
--- a/docs/source/api/schedulers.mdx
+++ b/docs/source/api/schedulers.mdx
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 
 To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch.
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
 
 ## API
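
For orientation alongside the restored docs entries above, here is a minimal sketch of how the re-exported Flax model classes are used (not part of the patch). The checkpoint id and the assumption that Flax weights are published for it are illustrative; the relevant point is that the Flax models keep their parameters outside the module, so `from_pretrained` returns a `(model, params)` pair.

```python
# Sketch only: the checkpoint id and the presence of Flax weights are assumptions.
import jax.numpy as jnp

from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel

# Flax modules are stateless: from_pretrained hands back the module definition
# together with a separate params pytree.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.float32
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", dtype=jnp.float32
)
```
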
diff --git a/setup.py b/setup.py
index 6488f4b3..874f3f0c 100644
--- a/setup.py
+++ b/setup.py
@@ -84,10 +84,13 @@ _deps = [
     "datasets",
     "filelock",
     "flake8>=3.8.3",
+    "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
     "huggingface-hub>=0.10.0",
     "importlib_metadata",
     "isort>=5.5.4",
+    "jax>=0.2.8,!=0.3.2,<=0.3.6",
+    "jaxlib>=0.1.65,<=0.3.6",
     "modelcards>=0.1.4",
     "numpy",
     "onnxruntime",
@@ -185,9 +188,15 @@ extras["test"] = deps_list(
     "torchvision",
     "transformers"
 )
+extras["torch"] = deps_list("torch")
+
+if os.name == "nt":  # windows
+    extras["flax"] = []  # jax is not supported on windows
+else:
+    extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
 extras["dev"] = (
-    extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
 )
 
 install_requires = [
@@ -198,7 +207,6 @@ install_requires = [
     deps["regex"],
     deps["requests"],
     deps["Pillow"],
-    deps["torch"]
 ]
 
 setup(
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 817df835..ac9ccceb 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,5 @@
 from .utils import (
+    is_flax_available,
     is_inflect_available,
     is_onnx_available,
     is_scipy_available,
@@ -60,3 +61,25 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
     from .pipelines import StableDiffusionOnnxPipeline
 else:
     from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
+
+if is_flax_available():
+    from .modeling_flax_utils import FlaxModelMixin
+    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
+    from .models.vae_flax import FlaxAutoencoderKL
+    from .pipeline_flax_utils import FlaxDiffusionPipeline
+    from .schedulers import (
+        FlaxDDIMScheduler,
+        FlaxDDPMScheduler,
+        FlaxKarrasVeScheduler,
+        FlaxLMSDiscreteScheduler,
+        FlaxPNDMScheduler,
+        FlaxSchedulerMixin,
+        FlaxScoreSdeVeScheduler,
+    )
+else:
+    from .utils.dummy_flax_objects import *  # noqa F403
+
+if is_flax_available() and is_transformers_available():
+    from .pipelines import FlaxStableDiffusionPipeline
+else:
+    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 7ea7a66f..8b10d70a 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -8,10 +8,13 @@ deps = {
     "datasets": "datasets",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
+    "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
     "huggingface-hub": "huggingface-hub>=0.10.0",
     "importlib_metadata": "importlib_metadata",
     "isort": "isort>=5.5.4",
+    "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
+    "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
     "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
     "onnxruntime": "onnxruntime",
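
The `setup.py` and `__init__.py` hunks above restore the `flax` extra and gate every Flax export on `is_flax_available()`, with dummy placeholder objects exported otherwise. A minimal sketch of what that means for users (assuming the extra is installed with `pip install "diffusers[flax]"`; on Windows the extra intentionally resolves to an empty list, as the setup.py hunk shows):

```python
# Sketch of the availability gating restored in src/diffusers/__init__.py:
# `import diffusers` succeeds with or without JAX/Flax installed; without it,
# the Flax names resolve to dummy objects that are meant to raise a helpful
# ImportError only when they are actually used.
from diffusers.utils import is_flax_available

if is_flax_available():
    from diffusers import FlaxPNDMScheduler, FlaxUNet2DConditionModel

    print("Flax backend found:", FlaxPNDMScheduler.__name__, FlaxUNet2DConditionModel.__name__)
else:
    print('Flax backend missing; install the extra with: pip install "diffusers[flax]"')
```
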
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index aff1ec1c..1242ad6f 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -12,10 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import is_torch_available
+from ..utils import is_flax_available, is_torch_available
 
 
 if is_torch_available():
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .vae import AutoencoderKL, VQModel
+
+if is_flax_available():
+    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+    from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 3f3df460..1c31595f 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -21,3 +21,6 @@ if is_torch_available() and is_transformers_available():
 
 if is_transformers_available() and is_onnx_available():
     from .stable_diffusion import StableDiffusionOnnxPipeline
+
+if is_transformers_available() and is_flax_available():
+    from .stable_diffusion import FlaxStableDiffusionPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 289c6e1a..615fa404 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -6,7 +6,7 @@ import numpy as np
 import PIL
 from PIL import Image
 
-from ...utils import BaseOutput, is_onnx_available, is_torch_available, is_transformers_available
+from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
 
 
 @dataclass
@@ -35,3 +35,27 @@ if is_transformers_available() and is_torch_available():
 
 if is_transformers_available() and is_onnx_available():
     from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
+
+if is_transformers_available() and is_flax_available():
+    import flax
+
+    @flax.struct.dataclass
+    class FlaxStableDiffusionPipelineOutput(BaseOutput):
+        """
+        Output class for Stable Diffusion pipelines.
+
+        Args:
+            images (`List[PIL.Image.Image]` or `np.ndarray`)
+                List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+                num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+            nsfw_content_detected (`List[bool]`)
+                List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+                (nsfw) content.
+        """
+
+        images: Union[List[PIL.Image.Image], np.ndarray]
+        nsfw_content_detected: List[bool]
+
+    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
+    from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
+    from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
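
The hunk above re-exports `FlaxStableDiffusionPipeline` together with its `flax.struct.dataclass` output. A usage sketch in the style of the project's Flax examples follows (not part of the patch); the checkpoint id, the `revision`/`dtype` choice and the pmap-style sharding are assumptions taken from those examples rather than from this diff.

```python
# Sketch adapted from the Flax Stable Diffusion examples; checkpoint id,
# revision and dtype are assumptions, not pinned down by this patch.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

# from_pretrained returns the pipeline plus a separate params pytree.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)

# Replicate the params and shard the inputs across devices for the jitted call.
params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, rng, num_inference_steps=50, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
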
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index a3c23d0f..a906c39e 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
-from ..utils import is_scipy_available, is_torch_available
+from ..utils import is_flax_available, is_scipy_available, is_torch_available
 
 
 if is_torch_available():
@@ -27,6 +27,17 @@ if is_torch_available():
 else:
     from ..utils.dummy_pt_objects import *  # noqa F403
 
+if is_flax_available():
+    from .scheduling_ddim_flax import FlaxDDIMScheduler
+    from .scheduling_ddpm_flax import FlaxDDPMScheduler
+    from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
+    from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
+    from .scheduling_pndm_flax import FlaxPNDMScheduler
+    from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+    from .scheduling_utils_flax import FlaxSchedulerMixin
+else:
+    from ..utils.dummy_flax_objects import *  # noqa F403
+
 if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
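
Finally, the scheduler hunk restores the functional Flax schedulers, whose mutable quantities live in an explicit state object (such as the `PNDMSchedulerState` imported earlier) rather than on the scheduler instance. A rough sketch of that pattern; the `create_state`/`set_timesteps` names and signatures are assumptions about this version's Flax scheduler modules, not something spelled out in the diff.

```python
# Rough sketch of the stateless Flax scheduler pattern; method names and
# signatures below are assumptions, not taken verbatim from this patch.
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()   # holds only configuration
state = scheduler.create_state()  # all mutable values live in this state pytree
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 4, 64, 64))

# Each denoising step would then thread the state through explicitly, e.g.:
#   output = scheduler.step(state, model_output, t, latents)
#   latents, state = output.prev_sample, output.state
```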