Fix ONNX conversion and inference (#1416)

This commit is contained in:
Anton Lozhkov 2022-11-25 14:51:17 +01:00 committed by GitHub
parent d52388f486
commit 86aa747da9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 94 deletions

View File

@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
feature_extractor = pipeline.feature_extractor
else:
safety_checker = None
feature_extractor = None
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=safety_checker,
feature_extractor=pipeline.feature_extractor,
feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
)
onnx_pipeline.save_pretrained(output_path)

View File

@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import numpy as np
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
callback_steps: Optional[int] = 1,
**kwargs,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype
latents_shape = (
batch_size * num_images_per_prompt,
4,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:

View File

@ -19,7 +19,6 @@ import numpy as np
import torch
import PIL
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,

View File

@ -19,7 +19,6 @@ import numpy as np
import torch
import PIL
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
prompt: Union[str, List[str]],
image: PIL.Image.Image,
mask_image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
)
num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)