Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)
Rename and deprecate
This commit is contained in:
parent
100e094cc9
commit
728a3f3ec1
|
@ -58,7 +58,7 @@ else:
|
|||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
|
||||
if is_torch_available() and is_transformers_available() and is_onnx_available():
|
||||
from .pipelines import StableDiffusionOnnxPipeline
|
||||
from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
|
||||
else:
|
||||
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available():
|
|||
)
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .stable_diffusion import StableDiffusionOnnxPipeline
|
||||
from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
from .stable_diffusion import FlaxStableDiffusionPipeline
|
||||
|
|
|
@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available():
|
|||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
|
||||
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
import flax
|
||||
|
|
|
@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
|||
from ...onnx_utils import OnnxRuntimeModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import logging
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
||||
class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
text_encoder: OnnxRuntimeModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
|||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_decoder: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: OnnxRuntimeModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
|
||||
deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
|
||||
super().__init__(
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
|
@ -4,6 +4,21 @@
|
|||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class OnnxStableDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "onnx"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "onnx"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "onnx"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "onnx"])
|
||||
|
||||
|
||||
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "onnx"]
|
||||
|
||||
|
|
|
@ -37,13 +37,13 @@ from diffusers import (
|
|||
LDMPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
OnnxStableDiffusionPipeline,
|
||||
PNDMPipeline,
|
||||
PNDMScheduler,
|
||||
ScoreSdeVePipeline,
|
||||
ScoreSdeVeScheduler,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
|
@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_stable_diffusion_onnx(self):
|
||||
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
|
||||
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
|
||||
|
@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionOnnxPipeline.from_pretrained(
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
|
Loading…
Reference in New Issue