diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 6068b961..f16c2cba 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -97,6 +97,14 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention +## StableDiffusionDepth2ImgPipeline +[[autodoc]] StableDiffusionDepth2ImgPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + ## StableDiffusionImageVariationPipeline [[autodoc]] StableDiffusionImageVariationPipeline - __call__ diff --git a/docs/source/api/pipelines/stable_diffusion_2.mdx b/docs/source/api/pipelines/stable_diffusion_2.mdx index 5df91950..baf40be1 100644 --- a/docs/source/api/pipelines/stable_diffusion_2.mdx +++ b/docs/source/api/pipelines/stable_diffusion_2.mdx @@ -30,6 +30,7 @@ Note that the architecture is more or less identical to [Stable Diffusion 1](./a - *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`] - *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`] - *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`] +- *Depth-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) with [`StableDiffusionDepth2ImagePipeline`] We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is. @@ -125,6 +126,37 @@ upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] upscaled_image.save("upsampled_cat.png") ``` +- *Depth-Guided Text-to-Image*: [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) [`StableDiffusionDepth2ImagePipeline`] + +**Installation** + +```bash +!pip install -U git+https://github.com/huggingface/transformers.git +!pip install diffusers[torch] +``` + +**Example** + +```python +import torch +import requests +from PIL import Image + +from diffusers import StableDiffusionDepth2ImgPipeline + +pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-depth", + torch_dtype=torch.float16, +).to("cuda") + + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +init_image = Image.open(requests.get(url, stream=True).raw) +prompt = "two tigers" +n_propmt = "bad, deformed, ugly, bad anotomy" +image = pipe(prompt=prompt, image=init_image, negative_prompt=n_propmt, strength=0.7).images[0] +``` + ### How to load and use different schedulers. The stable diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. diff --git a/setup.py b/setup.py index 64836d08..d6a16744 100644 --- a/setup.py +++ b/setup.py @@ -107,7 +107,7 @@ _deps = [ "tensorboard", "torch>=1.4", "torchvision", - "transformers>=4.21.0", + "transformers>=4.25.1", ] # this is a lookup table with items like: diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3b82c5a6..5a3c0c43 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -12,11 +12,24 @@ from .utils import ( is_scipy_available, is_torch_available, is_transformers_available, + is_transformers_version, is_unidecode_available, logging, ) +# Make sure `transformers` is up to date +if is_transformers_available(): + import transformers + + if is_transformers_version("<", "4.25.1"): + raise ImportError( + f"`diffusers` requires transformers >= 4.25.1 to function correctly, but {transformers.__version__} was" + " found in your environment. You can upgrade it with pip: `pip install transformers --upgrade`" + ) +else: + pass + try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -87,6 +100,7 @@ else: CycleDiffusionPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, + StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index b5f9d6d3..1ef1edc1 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -31,5 +31,5 @@ deps = { "tensorboard": "tensorboard", "torch": "torch>=1.4", "torchvision": "torchvision", - "transformers": "transformers>=4.21.0", + "transformers": "transformers>=4.25.1", } diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 89605ccd..9f7d1a05 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -44,6 +44,7 @@ else: from .paint_by_example import PaintByExamplePipeline from .stable_diffusion import ( CycleDiffusionPipeline, + StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index ac544cbe..2e92dfa3 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -46,13 +46,23 @@ if is_transformers_available() and is_torch_available(): from .safety_checker import StableDiffusionSafetyChecker try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline else: from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0.dev0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline +else: + from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline + + try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py new file mode 100644 index 00000000..f8751410 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -0,0 +1,564 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + depth_estimator: DPTForDepthEstimation, + feature_extractor: DPTFeatureExtractor, + ): + super().__init__() + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + depth_estimator=depth_estimator, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.depth_estimator]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + image = image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): + if isinstance(image, PIL.Image.Image): + width, height = image.size + width, height = map(lambda dim: dim - dim % 32, (width, height)) # resize to integer multiple of 32 + image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + width, height = image.size + else: + image = [img for img in image] + width, height = image[0].shape[-2:] + + if depth_map is None: + pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=device) + # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. + # So we use `torch.autocast` here for half precision inference. + context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext() + with context_manger: + depth_map = self.depth_estimator(pixel_values).predicted_depth + else: + depth_map = depth_map.to(device=device, dtype=dtype) + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(height // self.vae_scale_factor, width // self.vae_scale_factor), + mode="bicubic", + align_corners=False, + ) + + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 + depth_map = depth_map.to(dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if depth_map.shape[0] < batch_size: + depth_map = depth_map.repeat(batch_size, 1, 1, 1) + + depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map + return depth_map + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + depth_map: Optional[torch.FloatTensor] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare depth mask + depth_mask = self.prepare_depth_map( + image, + depth_map, + batch_size * num_images_per_prompt, + do_classifier_free_guidance, + text_embeddings.dtype, + device, + ) + + # 5. Preprocess image + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + else: + image = 2.0 * (image / 255.0) - 1.0 + + # 6. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py index 3c4b5208..6abda997 100644 --- a/src/diffusers/pipelines/versatile_diffusion/__init__.py +++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py @@ -7,7 +7,7 @@ from ...utils import ( try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index f5a2e55d..160b83a7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -79,6 +79,21 @@ class PaintByExamplePipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 6ebdf7d9..ef934aff 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -310,12 +310,6 @@ LIBROSA_IMPORT_ERROR = """ installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. """ -# docstyle-ignore -TENSORFLOW_IMPORT_ERROR = """ -{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the -installation page: https://www.tensorflow.org/install and follow the ones that match your environment. -""" - # docstyle-ignore TRANSFORMERS_IMPORT_ERROR = """ {0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip @@ -341,7 +335,6 @@ BACKENDS_MAPPING = OrderedDict( ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), @@ -361,12 +354,7 @@ def requires_backends(obj, backends): if failed: raise ImportError("".join(failed)) - if name in [ - "VersatileDiffusionTextToImagePipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionDualGuidedPipeline", - "StableDiffusionImageVariationPipeline", - ] and is_transformers_version("<", "4.25.0.dev0"): + if name in ["StableDiffusionDepth2ImgPipeline"] and is_transformers_version("<", "4.26.0.dev0"): raise ImportError( f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install" " git+https://github.com/huggingface/transformers \n```" diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py new file mode 100644 index 00000000..df074f6c --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -0,0 +1,573 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionDepth2ImgPipeline, + UNet2DConditionModel, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + DPTConfig, + DPTFeatureExtractor, + DPTForDepthEstimation, +) + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") +class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionDepth2ImgPipeline + test_save_load_optional_components = False + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=5, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + backbone_config = { + "global_padding": "same", + "layer_type": "bottleneck", + "depths": [3, 4, 9], + "out_features": ["stage1", "stage2", "stage3"], + "embedding_dynamic_padding": True, + "hidden_sizes": [96, 192, 384, 768], + "num_groups": 2, + } + depth_estimator_config = DPTConfig( + image_size=32, + patch_size=16, + num_channels=3, + hidden_size=32, + num_hidden_layers=4, + backbone_out_indices=(0, 1, 2, 3), + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + is_decoder=False, + initializer_range=0.02, + is_hybrid=True, + backbone_config=backbone_config, + backbone_featmap_shape=[1, 384, 24, 24], + ) + depth_estimator = DPTForDepthEstimation(depth_estimator_config) + feature_extractor = DPTFeatureExtractor.from_pretrained( + "hf-internal-testing/tiny-random-DPTForDepthEstimation" + ) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "depth_estimator": depth_estimator, + "feature_extractor": feature_extractor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") + def test_save_load_local(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output - output_loaded).max() + self.assertLess(max_diff, 3e-5) + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output - output_loaded).max() + self.assertLess(max_diff, 2e-2, "The output of the fp16 pipeline changed after saving and loading.") + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_float16_inference(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.half() + pipe_fp16 = self.pipeline_class(**components) + pipe_fp16.to(torch_device) + pipe_fp16.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(torch_device))[0] + output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0] + + max_diff = np.abs(output - output_fp16).max() + self.assertLess(max_diff, 1.3e-2, "The outputs of the fp16 and fp32 pipelines are too different.") + + @unittest.skipIf( + torch_device != "cuda" or not is_accelerate_available(), + reason="CPU offload is only available with CUDA and `accelerate` installed", + ) + def test_cpu_offload_forward_pass(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_without_offload = pipe(**inputs)[0] + + pipe.enable_sequential_cpu_offload() + inputs = self.get_dummy_inputs(torch_device) + output_with_offload = pipe(**inputs)[0] + + max_diff = np.abs(output_with_offload - output_without_offload).max() + self.assertLess(max_diff, 3e-5, "CPU offloading should not affect the inference results") + + @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") + def test_dict_tuple_outputs_equivalent(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = pipe(**self.get_dummy_inputs(torch_device)) + + output = pipe(**self.get_dummy_inputs(torch_device))[0] + output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0] + + max_diff = np.abs(output - output_tuple).max() + self.assertLess(max_diff, 3e-5) + + @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") + def test_num_inference_steps_consistent(self): + super().test_num_inference_steps_consistent() + + @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") + def test_progress_bar(self): + super().test_progress_bar() + + def test_stable_diffusion_depth2img_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionDepth2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + if torch_device == "mps": + expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) + else: + expected_slice = np.array([0.6907, 0.5135, 0.4688, 0.5169, 0.5738, 0.4600, 0.4435, 0.5640, 0.4653]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_depth2img_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionDepth2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + negative_prompt = "french fries" + output = sd_pipe(**inputs, negative_prompt=negative_prompt) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + if torch_device == "mps": + expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) + else: + expected_slice = np.array([0.755, 0.521, 0.473, 0.554, 0.629, 0.442, 0.440, 0.582, 0.449]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_depth2img_multiple_init_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionDepth2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"]] * 2 + inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) + image = sd_pipe(**inputs).images + image_slice = image[-1, -3:, -3:, -1] + + assert image.shape == (2, 32, 32, 3) + + if torch_device == "mps": + expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) + else: + expected_slice = np.array([0.6475, 0.6302, 0.5627, 0.5222, 0.4318, 0.5489, 0.5079, 0.4419, 0.4494]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_depth2img_num_images_per_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionDepth2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + # test num_images_per_prompt=1 (default) + inputs = self.get_dummy_inputs(device) + images = sd_pipe(**inputs).images + + assert images.shape == (1, 32, 32, 3) + + # test num_images_per_prompt=1 (default) for batch of prompts + batch_size = 2 + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"]] * batch_size + images = sd_pipe(**inputs).images + + assert images.shape == (batch_size, 32, 32, 3) + + # test num_images_per_prompt for single prompt + num_images_per_prompt = 2 + inputs = self.get_dummy_inputs(device) + images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + + assert images.shape == (num_images_per_prompt, 32, 32, 3) + + # test num_images_per_prompt for batch of prompts + batch_size = 2 + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"]] * batch_size + images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + + assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + + def test_stable_diffusion_depth2img_pil(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionDepth2ImgPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + + inputs["image"] = Image.fromarray(inputs["image"][0].permute(1, 2, 0).numpy().astype(np.uint8)) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + if torch_device == "mps": + expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) + else: + expected_slice = np.array([0.6853, 0.3740, 0.4856, 0.7130, 0.7402, 0.5535, 0.4828, 0.6182, 0.5053]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + +@slow +@require_torch_gpu +class StableDiffusionDepth2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_depth2img_pipeline_default(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-depth" + pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(model_id) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "two tigers" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + strength=0.75, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (480, 640, 3) + # depth2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_depth2img_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats_k_lms.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-depth" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(model_id, scheduler=lms) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "two tigers" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + strength=0.75, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (480, 640, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_depth2img_pipeline_ddim(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats_ddim.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-depth" + ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(model_id, scheduler=ddim) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "two tigers" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + strength=0.75, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (480, 640, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_depth2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 1: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 60, 80) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.7825, 0.5786, -0.9125, -0.9885, -1.0071, 2.7126, -0.8490, 0.3776, -0.0791] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 60, 80) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.6110, -0.2347, -0.5115, -1.1383, -1.4755, -0.5970, -0.9050, -0.7199, -0.8417] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png" + ) + + pipe = StableDiffusionDepth2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-depth") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "two tigers" + + generator = torch.Generator(device=torch_device).manual_seed(0) + pipe( + prompt=prompt, + image=init_image, + strength=0.75, + num_inference_steps=50, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 37 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/depth2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + + model_id = "stabilityai/stable-diffusion-2-depth" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( + model_id, scheduler=lms, safety_checker=None, revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + num_inference_steps=5, + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.9 GB is allocated + assert mem_bytes < 2.9 * 10**9 diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index f328a440..3a6d5139 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -26,7 +26,6 @@ import torch import PIL import safetensors.torch -import transformers from diffusers import ( AutoencoderKL, DDIMPipeline, @@ -533,9 +532,8 @@ class PipelineFastTests(unittest.TestCase): # Validate that the text encoder safetensor exists and are of the correct format text_encoder_path = os.path.join(tmpdirname, "text_encoder", "model.safetensors") - if transformers.__version__ >= "4.25.1": - assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}" - _ = safetensors.torch.load_file(text_encoder_path) + assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}" + _ = safetensors.torch.load_file(text_encoder_path) pipeline = StableDiffusionPipeline.from_pretrained(tmpdirname) assert pipeline.unet is not None diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index af5dd6a4..93f8edb8 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -11,7 +11,13 @@ from typing import Callable, Union import numpy as np import torch -from diffusers import CycleDiffusionPipeline, DanceDiffusionPipeline, DiffusionPipeline, StableDiffusionImg2ImgPipeline +from diffusers import ( + CycleDiffusionPipeline, + DanceDiffusionPipeline, + DiffusionPipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionImg2ImgPipeline, +) from diffusers.utils.import_utils import is_accelerate_available, is_xformers_available from diffusers.utils.testing_utils import require_torch, torch_device @@ -281,6 +287,7 @@ class PipelineTesterMixin: DanceDiffusionPipeline, CycleDiffusionPipeline, StableDiffusionImg2ImgPipeline, + StableDiffusionDepth2ImgPipeline, ): # FIXME: inconsistent outputs on MPS return diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 88b26682..38fccca1 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -91,7 +91,8 @@ def read_init(): objects.append(line[8:-2]) line_index += 1 - backend_specific_objects[backend] = objects + if len(objects) > 0: + backend_specific_objects[backend] = objects else: line_index += 1