StableDiffusionImageVariationPipeline (#1365)

* add StableDiffusionImageVariationPipeline

* add ini init

* use CLIPVisionModelWithProjection

* fix _encode_image

* add copied from

* fix copies

* add doc

* handle tensor in _encode_image

* add tests

* correct model_id

* remove copied from in enable_sequential_cpu_offload

* fix tests

* make slow tests pass

* update slow tests

* use temp model for now

* fix test_stable_diffusion_img_variation_intermediate_state

* fix test_stable_diffusion_img_variation_intermediate_state

* check for torch.Tensor

* quality

* fix name

* fix slow tests

* install transformers from source

* fix install

* fix install

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* input_image -> image

* remove deprication warnings

* fix test_stable_diffusion_img_variation_multiple_images

* make flake happy

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Suraj Patil 2022-11-23 14:36:39 +01:00 committed by GitHub
parent 9e234d8048
commit 0eb507f2af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 890 additions and 0 deletions

View File

@ -60,6 +60,7 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@ -127,6 +128,7 @@ jobs:
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
shell: arch -arch arm64 bash {0}

View File

@ -62,6 +62,7 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@ -131,6 +132,7 @@ jobs:
run: |
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |

View File

@ -88,3 +88,10 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionImageVariationPipeline
[[autodoc]] StableDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@ -69,6 +69,7 @@ if is_torch_available() and is_transformers_available():
AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,

View File

@ -19,6 +19,7 @@ if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,

View File

@ -30,6 +30,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
if is_transformers_available() and is_torch_available():
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy

View File

@ -0,0 +1,437 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionImageVariationPipeline(DiffusionPipeline):
r"""
Pipeline to generate variations from an input image using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
image_encoder: CLIPVisionModelWithProjection,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
uncond_embeddings = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
return image_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, height, width, callback_steps):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -64,6 +64,21 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@ -0,0 +1,424 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import numpy as np
import torch
from diffusers import (
AutoencoderKL,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImageVariationPipeline,
UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image
@property
def dummy_cond_unet(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
return model
@property
def dummy_image_encoder(self):
torch.manual_seed(0)
config = CLIPVisionConfig(
hidden_size=32,
projection_dim=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
image_size=32,
patch_size=4,
)
return CLIPVisionModelWithProjection(config)
@property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
def __init__(self):
self.pixel_values = torch.ones([0])
def to(self, device):
self.pixel_values.to(device)
return self
return Out()
return extract
def test_stable_diffusion_img_variation_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
image_encoder = self.dummy_image_encoder
init_image = self.dummy_image.to(device)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImageVariationPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
image_encoder=image_encoder,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
init_image,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
)
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
init_image,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
print(image_slice.flatten())
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.4935, 0.4784, 0.4802, 0.5027, 0.4805, 0.5149, 0.5143, 0.4879, 0.4731])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_img_variation_multiple_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
image_encoder = self.dummy_image_encoder
init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImageVariationPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
image_encoder=image_encoder,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
init_image,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
)
image = output.images
image_slice = image[-1, -3:, -3:, -1]
assert image.shape == (2, 128, 128, 3)
expected_slice = np.array([0.4939, 0.4627, 0.4831, 0.5710, 0.5387, 0.4428, 0.5230, 0.5545, 0.4586])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
device = "cpu"
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
image_encoder = self.dummy_image_encoder
init_image = self.dummy_image.to(device)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImageVariationPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
image_encoder=image_encoder,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
images = sd_pipe(
init_image,
num_inference_steps=2,
output_type="np",
).images
assert images.shape == (1, 128, 128, 3)
# test num_images_per_prompt=1 (default) for batch of images
batch_size = 2
images = sd_pipe(
init_image.repeat(batch_size, 1, 1, 1),
num_inference_steps=2,
output_type="np",
).images
assert images.shape == (batch_size, 128, 128, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
images = sd_pipe(
init_image,
num_inference_steps=2,
output_type="np",
num_images_per_prompt=num_images_per_prompt,
).images
assert images.shape == (num_images_per_prompt, 128, 128, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
images = sd_pipe(
init_image.repeat(batch_size, 1, 1, 1),
num_inference_steps=2,
output_type="np",
num_images_per_prompt=num_images_per_prompt,
).images
assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3)
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_img_variation_fp16(self):
"""Test that stable diffusion img2img works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
image_encoder = self.dummy_image_encoder
init_image = self.dummy_image.to(torch_device).float()
# put models in fp16
unet = unet.half()
vae = vae.half()
image_encoder = image_encoder.half()
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImageVariationPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
image_encoder=image_encoder,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe(
init_image,
generator=generator,
num_inference_steps=2,
output_type="np",
).images
assert image.shape == (1, 128, 128, 3)
@slow
@require_torch_gpu
class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion_img_variation_pipeline_default(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.jpg"
)
init_image = init_image.resize((512, 512))
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.npy"
)
model_id = "fusing/sd-image-variations-diffusers"
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
model_id,
safety_checker=None,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
init_image,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
# img2img is flaky across GPUs even in fp32, so using MAE here
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img_variation_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 37:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([2.285, 2.703, 1.969, 0.696, -1.323, 0.9253, -0.5464, -1.521, -2.537])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
test_callback_fn.has_been_called = False
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((512, 512))
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
"fusing/sd-image-variations-diffusers",
torch_dtype=torch.float16,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
init_image,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 51
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((512, 512))
model_id = "fusing/sd-image-variations-diffusers"
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
init_image,
guidance_scale=7.5,
generator=generator,
output_type="np",
num_inference_steps=5,
)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.6 GB is allocated
assert mem_bytes < 2.6 * 10**9