From 1138d63b519e37f0ce04e027b9f4a3261d27c628 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 16 Nov 2022 18:42:21 +0100 Subject: [PATCH] Temporary local test for PIL_INTERPOLATION (#1317) * Temporary local test for PIL_INTERPOLATION * Fix examples too. --- examples/community/imagic_stable_diffusion.py | 22 ++++++++++++++++++- examples/community/lpw_stable_diffusion.py | 22 ++++++++++++++++++- .../community/lpw_stable_diffusion_onnx.py | 21 +++++++++++++++++- .../textual_inversion/textual_inversion.py | 22 ++++++++++++++++++- .../textual_inversion_flax.py | 22 ++++++++++++++++++- 5 files changed, 104 insertions(+), 5 deletions(-) diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index d6d89283..4aee9531 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -17,11 +17,31 @@ from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import PIL_INTERPOLATION, logging +from diffusers.utils import logging from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 8c5f5b46..5f0c6b70 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -12,9 +12,29 @@ from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging +from diffusers.utils import deprecate, is_accelerate_available, logging from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index 268af775..7c699518 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -10,9 +10,28 @@ from diffusers.onnx_utils import OnnxRuntimeModel from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import PIL_INTERPOLATION, logging +from diffusers.utils import logging from transformers import CLIPFeatureExtractor, CLIPTokenizer +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 532ce4a7..a40caec2 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -18,13 +18,33 @@ from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from diffusers.utils import PIL_INTERPOLATION from huggingface_hub import HfFolder, Repository, whoami from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +import PIL +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = get_logger(__name__) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 008fe812..eaa0c691 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -23,7 +23,6 @@ from diffusers import ( FlaxUNet2DConditionModel, ) from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker -from diffusers.utils import PIL_INTERPOLATION from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard @@ -34,6 +33,27 @@ from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +import PIL +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = logging.getLogger(__name__)