Add image_processor (#2617)
* add image_processor

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent c0b4d72095
commit e52cd55615
@@ -0,0 +1,177 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Union

import numpy as np
import PIL
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION


class VaeImageProcessor(ConfigMixin):
    """
    Image Processor for VAE

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1]
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def numpy_to_pt(images):
        """
        Convert a numpy image to a pytorch tensor
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images):
        """
        Convert a pytorch tensor to a numpy image
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1]
        """
        return 2.0 * images - 1.0

    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
        """
        Resize a PIL image. Both height and width are rounded down to the nearest integer multiple of `vae_scale_factor`
        """
        w, h = images.size
        w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h))  # resize to integer multiple of vae_scale_factor
        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
        return images

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
    ) -> torch.Tensor:
        """
        Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support "
                f"{', '.join(str(x) for x in supported_formats)}"
            )

        if isinstance(image[0], PIL.Image.Image):
            if self.do_resize:
                image = [self.resize(i) for i in image]
            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
            image = np.stack(image, axis=0)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        return image

    def postprocess(
        self,
        image,
        output_type: str = "pil",
    ):
        if isinstance(image, torch.Tensor) and output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image
        elif output_type == "pil":
            return self.numpy_to_pil(image)
        else:
            raise ValueError(f"Unsupported output_type {output_type}.")
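
For orientation, here is a minimal usage sketch of the new class (not part of the diff); the dummy image size and the manual rescaling step are illustrative assumptions, and the import path matches the one used by the tests further below:

# Round-trip sketch: preprocess a PIL image, then convert the result back to each output format.
import numpy as np
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)  # defaults: do_resize=True, do_normalize=True

# Dummy 96x96 RGB image; 96 is already a multiple of vae_scale_factor, so resize is a no-op here.
pil_image = Image.fromarray(np.random.randint(0, 255, (96, 96, 3), dtype=np.uint8))

tensor = processor.preprocess(pil_image)  # torch.Tensor, shape (1, 3, 96, 96), values in [-1, 1]

# postprocess expects a [0, 1] tensor (what decode_latents now returns), so rescale first.
decoded = (tensor / 2 + 0.5).clamp(0, 1)
as_np = processor.postprocess(decoded, output_type="np")    # NHWC numpy array in [0, 1]
as_pil = processor.postprocess(decoded, output_type="pil")  # list of PIL.Image.Image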
@@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 
 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -192,7 +193,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
 
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -203,7 +203,11 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            requires_safety_checker=requires_safety_checker,
+        )
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -415,21 +419,17 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+        )
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -663,7 +663,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -703,15 +703,26 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+
+        image = self.decode_latents(latents)
 
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if self.safety_checker is not None:
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            has_nsfw_concept = False
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -22,6 +22,7 @@ from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -119,7 +120,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
     """
     _optional_components = ["safety_checker", "feature_extractor"]
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -196,7 +196,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
 
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -207,7 +206,11 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            requires_safety_checker=requires_safety_checker,
+        )
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -422,24 +425,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
 
         return prompt_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+        )
         return image, has_nsfw_concept
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -674,7 +671,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -714,15 +711,26 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+
+        image = self.decode_latents(latents)
 
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if self.safety_checker is not None:
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            has_nsfw_concept = False
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
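
Taken together, the pipeline changes mean callers no longer have to hand-convert images before or after the call. A hedged usage sketch (not part of the diff; the checkpoint name, prompt, and sizes are illustrative):

import numpy as np
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example checkpoint
pipe = pipe.to("cuda")

# `image` may now be a PIL.Image, a [0, 1] NHWC numpy array, or a [0, 1] NCHW torch tensor;
# non-PIL inputs must already have sizes divisible by the VAE scale factor.
init_image = np.random.rand(1, 512, 512, 3).astype(np.float32)

result = pipe(
    prompt="A fantasy landscape",
    image=init_image,
    strength=0.75,
    num_inference_steps=30,
    output_type="pt",  # "pil", "np", "pt" or "latent"; the old "numpy" now emits a deprecation warning
)
images = result.images  # torch.Tensor in [0, 1], shape (1, 3, 512, 512)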
@@ -21,7 +21,13 @@ import numpy as np
 import torch
 from transformers import XLMRobertaTokenizer
 
-from diffusers import AltDiffusionImg2ImgPipeline, AutoencoderKL, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AltDiffusionImg2ImgPipeline,
+    AutoencoderKL,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
     RobertaSeriesConfig,
     RobertaSeriesModelWithTransformation,
@@ -128,6 +134,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         )
+        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
         alt_pipe = alt_pipe.to(device)
         alt_pipe.set_progress_bar_config(disable=None)
 
@@ -191,6 +198,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         )
+        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
         alt_pipe = alt_pipe.to(torch_device)
         alt_pipe.set_progress_bar_config(disable=None)
 
@@ -30,6 +30,7 @@ from diffusers import (
     StableDiffusionImg2ImgPipeline,
     UNet2DConditionModel,
 )
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
 
@@ -94,19 +95,33 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0):
+    def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
+
+        if input_image_type == "pt":
+            input_image = image
+        elif input_image_type == "np":
+            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
+        elif input_image_type == "pil":
+            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
+            input_image = VaeImageProcessor.numpy_to_pil(input_image)
+        else:
+            raise ValueError(f"unsupported input_image_type {input_image_type}.")
+
+        if output_type not in ["pt", "np", "pil"]:
+            raise ValueError(f"unsupported output_type {output_type}")
+
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "image": image,
+            "image": input_image,
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": output_type,
         }
         return inputs
 
@@ -114,6 +129,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -130,6 +146,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -148,6 +165,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -169,6 +187,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -197,6 +216,36 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_attention_slicing_forward_pass(self):
         return super().test_attention_slicing_forward_pass()
 
+    @skip_mps
+    def test_pt_np_pil_outputs_equivalent(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
+        output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
+        output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]
+
+        assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
+        assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4
+
+    @skip_mps
+    def test_image_types_consistent(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
+        output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
+        output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]
+
+        assert np.abs(output_pt - output_np).max() <= 1e-4
+        assert np.abs(output_pil - output_np).max() <= 1e-2
+
 
 @slow
 @require_torch_gpu
@@ -219,7 +268,7 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
             "num_inference_steps": 3,
             "strength": 0.75,
             "guidance_scale": 7.5,
-            "output_type": "numpy",
+            "output_type": "np",
         }
         return inputs
 
@@ -426,7 +475,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
            "num_inference_steps": 50,
             "strength": 0.75,
             "guidance_scale": 7.5,
-            "output_type": "numpy",
+            "output_type": "np",
         }
         return inputs
 
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import PIL
import torch

from diffusers.image_processor import VaeImageProcessor


class ImageProcessorTest(unittest.TestCase):
    @property
    def dummy_sample(self):
        batch_size = 1
        num_channels = 3
        height = 8
        width = 8

        sample = torch.rand((batch_size, num_channels, height, width))

        return sample

    def to_np(self, image):
        if isinstance(image[0], PIL.Image.Image):
            return np.stack([np.array(i) for i in image], axis=0)
        elif isinstance(image, torch.Tensor):
            return image.cpu().numpy().transpose(0, 2, 3, 1)
        return image

    def test_vae_image_processor_pt(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt = self.dummy_sample
        input_np = self.to_np(input_pt)

        for output_type in ["pt", "np", "pil"]:
            out = image_processor.postprocess(
                image_processor.preprocess(input_pt),
                output_type=output_type,
            )
            out_np = self.to_np(out)
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_np(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)

        for output_type in ["pt", "np", "pil"]:
            out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type)

            out_np = self.to_np(out)
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_pil(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
        input_pil = image_processor.numpy_to_pil(input_np)

        for output_type in ["pt", "np", "pil"]:
            out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type)
            for i, o in zip(input_pil, out):
                in_np = np.array(i)
                out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
                assert (
                    np.abs(in_np - out_np).max() < 1e-6
                ), f"decoded output does not match input for output_type {output_type}"

    def test_preprocess_input_3d(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt_4d = self.dummy_sample
        input_pt_3d = input_pt_4d.squeeze(0)

        out_pt_4d = image_processor.postprocess(
            image_processor.preprocess(input_pt_4d),
            output_type="np",
        )
        out_pt_3d = image_processor.postprocess(
            image_processor.preprocess(input_pt_3d),
            output_type="np",
        )

        input_np_4d = self.to_np(self.dummy_sample)
        input_np_3d = input_np_4d.squeeze(0)

        out_np_4d = image_processor.postprocess(
            image_processor.preprocess(input_np_4d),
            output_type="np",
        )
        out_np_3d = image_processor.postprocess(
            image_processor.preprocess(input_np_3d),
            output_type="np",
        )

        assert np.abs(out_pt_4d - out_pt_3d).max() < 1e-6
        assert np.abs(out_np_4d - out_np_3d).max() < 1e-6
    def test_preprocess_input_list(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt_4d = self.dummy_sample
        input_pt_list = list(input_pt_4d)

        out_pt_4d = image_processor.postprocess(
            image_processor.preprocess(input_pt_4d),
            output_type="np",
        )

        out_pt_list = image_processor.postprocess(
            image_processor.preprocess(input_pt_list),
            output_type="np",
        )

        input_np_4d = self.to_np(self.dummy_sample)
        input_np_list = list(input_np_4d)

        out_np_4d = image_processor.postprocess(
            image_processor.preprocess(input_np_4d),
            output_type="np",
        )

        out_np_list = image_processor.postprocess(
            image_processor.preprocess(input_np_list),
            output_type="np",
        )

        assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
        assert np.abs(out_np_4d - out_np_list).max() < 1e-6
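
For quick reference, a small sketch (not part of the commit) of the static conversion helpers these tests exercise; the dummy shapes and values are illustrative assumptions:

import numpy as np

from diffusers.image_processor import VaeImageProcessor

np_images = np.random.rand(2, 8, 8, 3).astype(np.float32)  # NHWC numpy batch, values in [0, 1]

pt_images = VaeImageProcessor.numpy_to_pt(np_images)       # NCHW torch tensor, still in [0, 1]
back_to_np = VaeImageProcessor.pt_to_numpy(pt_images)      # back to NHWC numpy, values unchanged
pil_images = VaeImageProcessor.numpy_to_pil(np_images)     # list of 2 uint8 PIL images
normalized = VaeImageProcessor.normalize(pt_images)        # rescaled to [-1, 1], what the VAE expects

assert np.abs(back_to_np - np_images).max() < 1e-6         # the same round trip the tests assert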