Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)

Rename and deprecate
This commit is contained in:
Anton Lozhkov 2022-10-18 09:14:30 +02:00 committed by GitHub
parent 100e094cc9
commit 728a3f3ec1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 8 deletions

View File

@ -58,7 +58,7 @@ else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import StableDiffusionOnnxPipeline
from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

View File

@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available():
)
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline
from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline

View File

@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available():
from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
import flax

View File

@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__)
class StableDiffusionOnnxPipeline(DiffusionPipeline):
class OnnxStableDiffusionPipeline(DiffusionPipeline):
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
def __init__(
self,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
):
deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
super().__init__(
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)

View File

@ -4,6 +4,21 @@
from ..utils import DummyObject, requires_backends
class OnnxStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "onnx"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "onnx"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "onnx"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "onnx"])
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "onnx"]

View File

@ -37,13 +37,13 @@ from diffusers import (
LDMPipeline,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
OnnxStableDiffusionPipeline,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionOnnxPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
UNet2DModel,
@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_stable_diffusion_onnx(self):
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase):
test_callback_fn.has_been_called = False
pipe = StableDiffusionOnnxPipeline.from_pretrained(
pipe = OnnxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)