diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0a0b0b49..0041030c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -58,7 +58,7 @@ else: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 if is_torch_available() and is_transformers_available() and is_onnx_available(): - from .pipelines import StableDiffusionOnnxPipeline + from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline else: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1c31595f..01391b0d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available(): ) if is_transformers_available() and is_onnx_available(): - from .stable_diffusion import StableDiffusionOnnxPipeline + from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 8c07afe5..b1f3240f 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available(): from .safety_checker import StableDiffusionSafetyChecker if is_transformers_available() and is_onnx_available(): - from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline + from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): import flax diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py similarity index 89% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py rename to src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index acd6446a..91bf5012 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...onnx_utils import OnnxRuntimeModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) -class StableDiffusionOnnxPipeline(DiffusionPipeline): +class OnnxStableDiffusionPipeline(DiffusionPipeline): vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer @@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): + def __init__( + self, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." + deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) + super().__init__( + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index d099b837..72ca97f5 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -4,6 +4,21 @@ from ..utils import DummyObject, requires_backends +class OnnxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + class StableDiffusionOnnxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 69a45c42..659e6955 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -37,13 +37,13 @@ from diffusers import ( LDMPipeline, LDMTextToImagePipeline, LMSDiscreteScheduler, + OnnxStableDiffusionPipeline, PNDMPipeline, PNDMScheduler, ScoreSdeVePipeline, ScoreSdeVeScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, - StableDiffusionOnnxPipeline, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, @@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_stable_diffusion_onnx(self): - sd_pipe = StableDiffusionOnnxPipeline.from_pretrained( + sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" ) @@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase): test_callback_fn.has_been_called = False - pipe = StableDiffusionOnnxPipeline.from_pretrained( + pipe = OnnxStableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" ) pipe.set_progress_bar_config(disable=None)