Fix ONNX conversion and inference (#1416)
parent d52388f486
commit 86aa747da9
--- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         )
         del pipeline.safety_checker
         safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+        feature_extractor = pipeline.feature_extractor
     else:
         safety_checker = None
+        feature_extractor = None
 
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
         safety_checker=safety_checker,
-        feature_extractor=pipeline.feature_extractor,
+        feature_extractor=feature_extractor,
+        requires_safety_checker=safety_checker is not None,
     )
 
     onnx_pipeline.save_pretrained(output_path)
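With the feature extractor and `requires_safety_checker` now wired through explicitly, the exported directory is self-consistent whether or not a safety checker was kept. A minimal usage sketch of loading the converted output, assuming a hypothetical local export path `./sd_onnx` and an installed `onnxruntime`:

```python
# Minimal sketch, not part of this commit. "./sd_onnx" stands in for the
# output_path that was passed to convert_models.
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained("./sd_onnx", provider="CPUExecutionProvider")
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```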
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch
 
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
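`_optional_components` marks modules that `DiffusionPipeline` may load as `None` without raising, which is what lets a pipeline saved without a safety checker round-trip through `save_pretrained`/`from_pretrained`. A hedged sketch of what the new class attribute enables (the path is a placeholder):

```python
# Sketch: because "safety_checker" and "feature_extractor" are listed in
# _optional_components, loading no longer fails when both are absent.
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./sd_onnx",  # placeholder for a local export without a safety checker
    provider="CPUExecutionProvider",
    safety_checker=None,
    feature_extractor=None,
)
```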
@@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
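The deleted block only makes sense for torch UNets: it reads `unet.config._diffusers_version` and patches `unet._internal_dict`, while the ONNX pipelines receive an `OnnxRuntimeModel` here, so the gate is dropped along with the now-unused `packaging` import. For reference, a generic sketch of the version-comparison pattern the block used (not this commit's code):

```python
from packaging import version

def is_older_than(config_version: str, cutoff: str = "0.9.0.dev0") -> bool:
    # Mirror the removed code: compare only the release segment via .base_version
    return version.parse(version.parse(config_version).base_version) < version.parse(cutoff)

assert is_older_than("0.8.1") is True
assert is_older_than("0.10.0") is False
```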
@@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
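This deleted line was a direct cause of the inference breakage: the ONNX pipeline registers `vae_encoder` and `vae_decoder` but never a `vae` module, so `self.vae.config.block_out_channels` raised an `AttributeError` as soon as the pipeline was constructed. For context, a sketch of what the expression computes on a torch pipeline, where a `vae` with a config does exist (channel values are illustrative SD v1 numbers):

```python
# Sketch (torch pipelines only): four VAE blocks give a spatial
# downscale factor of 2 ** (4 - 1) = 8, i.e. 512px images -> 64x64 latents.
block_out_channels = [128, 256, 512, 512]  # illustrative SD v1 VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8
```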
@@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
-
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):
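With `self.vae_scale_factor` gone, the `height or self.unet.config.sample_size * self.vae_scale_factor` fallback had nothing valid to compute, so the ONNX pipelines return to literal 512 defaults, the resolution SD v1 checkpoints were trained at. Explicit sizes still override the default; a usage sketch with a placeholder path:

```python
# Sketch: omitting height/width now means 512x512; both can still be overridden.
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained("./sd_onnx", provider="CPUExecutionProvider")
image = pipe("a watercolor landscape", height=512, width=768).images[0]
```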
@@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
 
         # get the initial random noise unless the user supplied it
         latents_dtype = text_embeddings.dtype
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            4,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
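The flattened tuple also replaces `self.vae_scale_factor` with the literal `// 8`, matching the hardcoded 512 defaults above. Since the ONNX `__call__` draws noise from a NumPy generator rather than a `torch.Generator`, reproducible latents can be precomputed like this (a sketch assuming one 512x512 image and float32 embeddings):

```python
import numpy as np

# Sketch: build latents exactly the way the fixed pipeline does.
batch_size, num_images_per_prompt, height, width = 1, 1, 512, 512
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)

generator = np.random.RandomState(42)  # seeded for reproducibility
latents = generator.randn(*latents_shape).astype(np.float32)
assert latents.shape == (1, 4, 64, 64)
```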
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -19,7 +19,6 @@ import numpy as np
 import torch
 
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
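The img2img pipeline receives the same `_optional_components` and version-gate changes as the text-to-image one. A hedged usage sketch, assuming a local export and an input image on disk (both paths are placeholders):

```python
# Sketch: ONNX img2img after this fix.
from PIL import Image

from diffusers import OnnxStableDiffusionImg2ImgPipeline

pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
    "./sd_onnx", provider="CPUExecutionProvider"
)
init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))
image = pipe("a detailed oil painting", init_image, strength=0.75).images[0]
image.save("painting.png")
```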
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -19,7 +19,6 @@ import numpy as np
 import torch
 
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
@@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]],
         image: PIL.Image.Image,
         mask_image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to 512):
                 The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
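The docstring now states the literal 512 defaults instead of the no-longer-valid `self.unet.config.sample_size * self.vae_scale_factor` expression. A usage sketch with the argument names from the signature above (paths are placeholders; white mask pixels mark the region to repaint):

```python
# Sketch: ONNX inpainting after this fix.
from PIL import Image

from diffusers import OnnxStableDiffusionInpaintPipeline

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "./sd_inpaint_onnx", provider="CPUExecutionProvider"
)
init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("RGB").resize((512, 512))
image = pipe("a red brick wall", image=init_image, mask_image=mask).images[0]
```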
@@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         )
 
         num_channels_latents = NUM_LATENT_CHANNELS
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)