Fix ONNX conversion and inference (#1416)
This commit is contained in:
parent
d52388f486
commit
86aa747da9
|
@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
|||
)
|
||||
del pipeline.safety_checker
|
||||
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
|
||||
feature_extractor = pipeline.feature_extractor
|
||||
else:
|
||||
safety_checker = None
|
||||
feature_extractor = None
|
||||
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
|
||||
|
@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
|||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=pipeline.feature_extractor,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
|
||||
onnx_pipeline.save_pretrained(output_path)
|
||||
|
|
|
@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
|
@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
|
@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
|
@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
|
@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
|
@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_dtype = text_embeddings.dtype
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
4,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
elif latents.shape != latents_shape:
|
||||
|
|
|
@ -19,7 +19,6 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
import PIL
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
|
@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
|
@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
|
|
@ -19,7 +19,6 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
import PIL
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
|
@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
|
@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
|
@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
prompt: Union[str, List[str]],
|
||||
image: PIL.Image.Image,
|
||||
mask_image: PIL.Image.Image,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
|
@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
||||
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
||||
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
|
@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
)
|
||||
|
||||
num_channels_latents = NUM_LATENT_CHANNELS
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
|
|
Loading…
Reference in New Issue