Fix ONNX conversion and inference (#1416)

This commit is contained in:
Anton Lozhkov 2022-11-25 14:51:17 +01:00 committed by GitHub
parent d52388f486
commit 86aa747da9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 94 deletions

View File

@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
) )
del pipeline.safety_checker del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker") safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
feature_extractor = pipeline.feature_extractor
else: else:
safety_checker = None safety_checker = None
feature_extractor = None
onnx_pipeline = OnnxStableDiffusionPipeline( onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler, scheduler=pipeline.scheduler,
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=pipeline.feature_extractor, feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
) )
onnx_pipeline.save_pretrained(output_path) onnx_pipeline.save_pretrained(output_path)

View File

@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
vae_encoder: OnnxRuntimeModel, vae_encoder: OnnxRuntimeModel,
@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
height: Optional[int] = None, height: Optional[int] = 512,
width: Optional[int] = None, width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50, num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5, guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs, **kwargs,
): ):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif isinstance(prompt, list):
@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
latents_dtype = text_embeddings.dtype latents_dtype = text_embeddings.dtype
latents_shape = ( latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
batch_size * num_images_per_prompt,
4,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None: if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype) latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape: elif latents.shape != latents_shape:

View File

@ -19,7 +19,6 @@ import numpy as np
import torch import torch
import PIL import PIL
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
vae_encoder: OnnxRuntimeModel, vae_encoder: OnnxRuntimeModel,
@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,

View File

@ -19,7 +19,6 @@ import numpy as np
import torch import torch
import PIL import PIL
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
vae_encoder: OnnxRuntimeModel, vae_encoder: OnnxRuntimeModel,
@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
image: PIL.Image.Image, image: PIL.Image.Image,
mask_image: PIL.Image.Image, mask_image: PIL.Image.Image,
height: Optional[int] = None, height: Optional[int] = 512,
width: Optional[int] = None, width: Optional[int] = 512,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`. instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image. The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image. The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50): num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
) )
num_channels_latents = NUM_LATENT_CHANNELS num_channels_latents = NUM_LATENT_CHANNELS
latents_shape = ( latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
batch_size * num_images_per_prompt,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_dtype = text_embeddings.dtype latents_dtype = text_embeddings.dtype
if latents is None: if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype) latents = generator.randn(*latents_shape).astype(latents_dtype)