diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 0b2fa15d..b3f3a911 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,14 @@ import numpy as np import PIL from PIL import Image -from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available +from ...utils import ( + BaseOutput, + is_flax_available, + is_onnx_available, + is_torch_available, + is_transformers_available, + is_transformers_version, +) @dataclass @@ -30,12 +37,16 @@ class StableDiffusionPipelineOutput(BaseOutput): if is_transformers_available() and is_torch_available(): from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline - from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .safety_checker import StableDiffusionSafetyChecker +if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0"): + from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline +else: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline + if is_transformers_available() and is_onnx_available(): from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py index 65bc1b72..ebe4343d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/__init__.py +++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py @@ -1,9 +1,16 @@ -from ...utils import is_torch_available, is_transformers_available +from ...utils import is_torch_available, is_transformers_available, is_transformers_version -if is_transformers_available() and is_torch_available(): +if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0"): from .modeling_text_unet import UNetFlatConditionModel from .pipeline_versatile_diffusion import VersatileDiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline +else: + from ...utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 909d878e..e86f3b80 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -33,6 +33,7 @@ from .import_utils import ( is_torch_available, is_torch_version, is_transformers_available, + is_transformers_version, is_unidecode_available, requires_backends, ) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 005cbb61..ddbd9350 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -303,6 +303,17 @@ def requires_backends(obj, backends): if failed: raise ImportError("".join(failed)) + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + ] and is_transformers_version("<", "4.25.0"): + raise ImportError( + f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install" + " git+https://github.com/huggingface/transformers \n```" + ) + class DummyObject(type): """ @@ -347,3 +358,17 @@ def is_torch_version(operation: str, version: str): A string version of PyTorch """ return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version)