Hotfix for AttributeErrors in OnnxStableDiffusionInpaintPipelineLegacy (#1448)

This commit is contained in:
Anton Lozhkov 2022-11-28 14:18:14 +01:00 committed by GitHub
parent 5755d16868
commit edf22c052e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 24 deletions

View File

@ -5,7 +5,6 @@ import numpy as np
import torch
import PIL
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
@ -68,6 +67,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
@ -134,27 +135,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@ -165,7 +145,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@ -372,7 +351,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = preprocess_mask(mask_image, 8)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)