Add image_processor (#2617)

* add image_processor

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
YiYi Xu 2023-03-15 07:55:49 -10:00 committed by GitHub
parent c0b4d72095
commit e52cd55615
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 447 additions and 45 deletions

View File

@ -0,0 +1,177 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Union
import numpy as np
import PIL
import torch
from PIL import Image
from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION
class VaeImageProcessor(ConfigMixin):
"""
Image Processor for VAE
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1]
"""
config_name = CONFIG_NAME
@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
):
super().__init__()
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
@staticmethod
def numpy_to_pt(images):
"""
Convert a numpy image to a pytorch tensor
"""
if images.ndim == 3:
images = images[..., None]
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
return images
@staticmethod
def pt_to_numpy(images):
"""
Convert a numpy image to a pytorch tensor
"""
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
return images
@staticmethod
def normalize(images):
"""
Normalize an image array to [-1,1]
"""
return 2.0 * images - 1.0
def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
"""
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
"""
w, h = images.size
w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h)) # resize to integer multiple of vae_scale_factor
images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
return images
def preprocess(
self,
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
) -> torch.Tensor:
"""
Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors"
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
if isinstance(image, supported_formats):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
)
if isinstance(image[0], PIL.Image.Image):
if self.do_resize:
image = [self.resize(i) for i in image]
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
image = np.stack(image, axis=0) # to np
image = self.numpy_to_pt(image) # to pt
elif isinstance(image[0], np.ndarray):
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
image = self.numpy_to_pt(image)
_, _, height, width = image.shape
if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
_, _, height, width = image.shape
if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
# expected range [0,1], normalize to [-1,1]
do_normalize = self.do_normalize
if image.min() < 0:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
FutureWarning,
)
do_normalize = False
if do_normalize:
image = self.normalize(image)
return image
def postprocess(
self,
image,
output_type: str = "pil",
):
if isinstance(image, torch.Tensor) and output_type == "pt":
return image
image = self.pt_to_numpy(image)
if output_type == "np":
return image
elif output_type == "pil":
return self.numpy_to_pil(image)
else:
raise ValueError(f"Unsupported output_type {output_type}.")

View File

@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@ -192,7 +193,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@ -203,7 +203,11 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
@ -415,21 +419,17 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return prompt_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_extra_step_kwargs(self, generator, eta):
@ -663,7 +663,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
)
# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@ -703,15 +703,26 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"
if output_type == "latent":
image = latents
has_nsfw_concept = None
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

View File

@ -22,6 +22,7 @@ from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@ -119,7 +120,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"""
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
self,
vae: AutoencoderKL,
@ -196,7 +196,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@ -207,7 +206,11 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
@ -422,24 +425,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@ -674,7 +671,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
)
# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@ -714,15 +711,26 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"
if output_type == "latent":
image = latents
has_nsfw_concept = None
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

View File

@ -21,7 +21,13 @@ import numpy as np
import torch
from transformers import XLMRobertaTokenizer
from diffusers import AltDiffusionImg2ImgPipeline, AutoencoderKL, PNDMScheduler, UNet2DConditionModel
from diffusers import (
AltDiffusionImg2ImgPipeline,
AutoencoderKL,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
RobertaSeriesConfig,
RobertaSeriesModelWithTransformation,
@ -128,6 +134,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
alt_pipe = alt_pipe.to(device)
alt_pipe.set_progress_bar_config(disable=None)
@ -191,6 +198,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
alt_pipe = alt_pipe.to(torch_device)
alt_pipe.set_progress_bar_config(disable=None)

View File

@ -30,6 +30,7 @@ from diffusers import (
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
@ -94,19 +95,33 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
}
return components
def get_dummy_inputs(self, device, seed=0):
def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
if input_image_type == "pt":
input_image = image
elif input_image_type == "np":
input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
elif input_image_type == "pil":
input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
input_image = VaeImageProcessor.numpy_to_pil(input_image)
else:
raise ValueError(f"unsupported input_image_type {input_image_type}.")
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"unsupported output_type {output_type}")
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"image": input_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
"output_type": output_type,
}
return inputs
@ -114,6 +129,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
@ -130,6 +146,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
@ -148,6 +165,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
@ -169,6 +187,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
@ -197,6 +216,36 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
def test_attention_slicing_forward_pass(self):
return super().test_attention_slicing_forward_pass()
@skip_mps
def test_pt_np_pil_outputs_equivalent(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]
assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4
@skip_mps
def test_image_types_consistent(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]
assert np.abs(output_pt - output_np).max() <= 1e-4
assert np.abs(output_pil - output_np).max() <= 1e-2
@slow
@require_torch_gpu
@ -219,7 +268,7 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
"output_type": "np",
}
return inputs
@ -426,7 +475,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
"num_inference_steps": 50,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
"output_type": "np",
}
return inputs

View File

@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
class ImageProcessorTest(unittest.TestCase):
@property
def dummy_sample(self):
batch_size = 1
num_channels = 3
height = 8
width = 8
sample = torch.rand((batch_size, num_channels, height, width))
return sample
def to_np(self, image):
if isinstance(image[0], PIL.Image.Image):
return np.stack([np.array(i) for i in image], axis=0)
elif isinstance(image, torch.Tensor):
return image.cpu().numpy().transpose(0, 2, 3, 1)
return image
def test_vae_image_processor_pt(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
input_pt = self.dummy_sample
input_np = self.to_np(input_pt)
for output_type in ["pt", "np", "pil"]:
out = image_processor.postprocess(
image_processor.preprocess(input_pt),
output_type=output_type,
)
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
assert (
np.abs(in_np - out_np).max() < 1e-6
), f"decoded output does not match input for output_type {output_type}"
def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
for output_type in ["pt", "np", "pil"]:
out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type)
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
assert (
np.abs(in_np - out_np).max() < 1e-6
), f"decoded output does not match input for output_type {output_type}"
def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
input_pil = image_processor.numpy_to_pil(input_np)
for output_type in ["pt", "np", "pil"]:
out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type)
for i, o in zip(input_pil, out):
in_np = np.array(i)
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
assert (
np.abs(in_np - out_np).max() < 1e-6
), f"decoded output does not match input for output_type {output_type}"
def test_preprocess_input_3d(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
input_pt_4d = self.dummy_sample
input_pt_3d = input_pt_4d.squeeze(0)
out_pt_4d = image_processor.postprocess(
image_processor.preprocess(input_pt_4d),
output_type="np",
)
out_pt_3d = image_processor.postprocess(
image_processor.preprocess(input_pt_3d),
output_type="np",
)
input_np_4d = self.to_np(self.dummy_sample)
input_np_3d = input_np_4d.squeeze(0)
out_np_4d = image_processor.postprocess(
image_processor.preprocess(input_np_4d),
output_type="np",
)
out_np_3d = image_processor.postprocess(
image_processor.preprocess(input_np_3d),
output_type="np",
)
assert np.abs(out_pt_4d - out_pt_3d).max() < 1e-6
assert np.abs(out_np_4d - out_np_3d).max() < 1e-6
def test_preprocess_input_list(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
input_pt_4d = self.dummy_sample
input_pt_list = list(input_pt_4d)
out_pt_4d = image_processor.postprocess(
image_processor.preprocess(input_pt_4d),
output_type="np",
)
out_pt_list = image_processor.postprocess(
image_processor.preprocess(input_pt_list),
output_type="np",
)
input_np_4d = self.to_np(self.dummy_sample)
list(input_np_4d)
out_np_4d = image_processor.postprocess(
image_processor.preprocess(input_pt_4d),
output_type="np",
)
out_np_list = image_processor.postprocess(
image_processor.preprocess(input_pt_list),
output_type="np",
)
assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
assert np.abs(out_np_4d - out_np_list).max() < 1e-6