From 2558977bc7dea6c72b9d619a471b8cd4d6cc1594 Mon Sep 17 00:00:00 2001 From: James R T Date: Mon, 3 Oct 2022 01:56:36 +0800 Subject: [PATCH] Add callback parameters for Stable Diffusion pipelines (#521) * Add callback parameters for Stable Diffusion pipelines Signed-off-by: James R T * Lint code with `black --preview` Signed-off-by: James R T * Refactor callback implementation for Stable Diffusion pipelines * Fix missing imports Signed-off-by: James R T * Fix documentation format Signed-off-by: James R T * Add kwargs parameter to standardize with other pipelines Signed-off-by: James R T * Modify Stable Diffusion pipeline callback parameters Signed-off-by: James R T * Remove useless imports Signed-off-by: James R T * Change types for timestep and onnx latents * Fix docstring style * Return decode_latents and run_safety_checker back into __call__ * Remove unused imports * Add intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T * Fix intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T Signed-off-by: James R T --- .../pipeline_stable_diffusion.py | 24 ++- .../pipeline_stable_diffusion_img2img.py | 26 ++- .../pipeline_stable_diffusion_inpaint.py | 27 ++- .../pipeline_stable_diffusion_onnx.py | 20 +- tests/test_pipelines.py | 174 ++++++++++++++++++ 5 files changed, 259 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5c6890db..d5407e28 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import torch @@ -122,6 +122,8 @@ class StableDiffusionPipeline(DiffusionPipeline): latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, ): r""" @@ -159,6 +161,12 @@ class StableDiffusionPipeline(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -178,6 +186,14 @@ class StableDiffusionPipeline(DiffusionPipeline): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -277,14 +293,16 @@ class StableDiffusionPipeline(DiffusionPipeline): else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # scale and decode the image latents with vae + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - # run safety checker safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 8f2bb67b..63178772 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -133,6 +133,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -170,6 +173,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -188,6 +197,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -265,6 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) + # Some schedulers like PNDM have timesteps as arrays # It's more optimzed to move all timesteps to correct device beforehand timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) @@ -295,14 +313,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # scale and decode the image latents with vae + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - # run safety checker safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 2e792df1..37f03dfb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -149,6 +149,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -190,6 +193,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -208,6 +217,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -297,7 +314,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): extra_step_kwargs["eta"] = eta latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + # Some schedulers like PNDM have timesteps as arrays # It's more optimzed to move all timesteps to correct device beforehand timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) @@ -331,14 +350,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): latents = (init_latents_proper * mask) + (latents * (1 - mask)) - # scale and decode the image latents with vae + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - # run safety checker safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 07e9c1d9..48e0e147 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np @@ -56,6 +56,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, ): if isinstance(prompt, str): @@ -68,6 +70,14 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -151,14 +161,18 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # scale and decode the image latents with vae + latents = np.array(latents) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + latents = 1 / 0.18215 * latents image = self.vae_decoder(latent_sample=latents)[0] image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) - # run safety checker safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 4ce82cc7..d0d78171 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1435,3 +1435,177 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_text2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 51 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_img2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + + pipe 
= StableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 38 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_inpaint_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 38 + + @slow + def test_stable_diffusion_onnx_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.6254, -0.2742, -1.0710, 0.2296, -1.1683, 0.6913, -2.0605, -0.0682, 0.9700] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionOnnxPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CPUExecutionProvider" + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "Andromeda galaxy in a bottle" + + np.random.seed(0) + pipe(prompt=prompt, num_inference_steps=50, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1) + assert test_callback_fn.has_been_called + assert number_of_steps == 51
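
Example usage of the callback API added by this patch (a minimal sketch: the `callback(step, timestep, latents)` signature, `callback_steps`, and the `from_pretrained` arguments mirror the pipeline code and tests above, while the `log_latents` function, the `intermediate_latents` list, and the chosen prompt/step values are illustrative only):

    import torch
    from diffusers import StableDiffusionPipeline

    # Collect the intermediate latents handed to the callback during denoising.
    intermediate_latents = []

    def log_latents(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # For a 512x512 image the latents have shape (1, 4, 64, 64).
        intermediate_latents.append(latents.detach().cpu())

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    image = pipe(
        prompt="Andromeda galaxy in a bottle",
        num_inference_steps=50,
        guidance_scale=7.5,
        callback=log_latents,
        callback_steps=5,  # must be a positive integer; the callback fires whenever step % callback_steps == 0
    ).images[0]

Passing `callback_steps=1` (the default) invokes the callback at every scheduler step, which is what the intermediate-state tests in this patch rely on; larger values reduce the overhead when the callback is only used for coarse progress reporting or previews.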