Temporary local test for PIL_INTERPOLATION (#1317)
* Temporary local test for PIL_INTERPOLATION * Fix examples too.
This commit is contained in:
parent
afdd7bb635
commit
1138d63b51
|
@ -17,11 +17,31 @@ from diffusers.pipeline_utils import DiffusionPipeline
|
|||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
from diffusers.utils import logging
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
|
|
@ -12,9 +12,29 @@ from diffusers.pipeline_utils import DiffusionPipeline
|
|||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging
|
||||
from diffusers.utils import deprecate, is_accelerate_available, logging
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
|
|
@ -10,9 +10,28 @@ from diffusers.onnx_utils import OnnxRuntimeModel
|
|||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
from diffusers.utils import logging
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
|
|
@ -18,13 +18,33 @@ from accelerate.utils import set_seed
|
|||
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
import PIL
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ from diffusers import (
|
|||
FlaxUNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
from flax import jax_utils
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
|
@ -34,6 +33,27 @@ from tqdm.auto import tqdm
|
|||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
||||
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
import PIL
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue