Add callback parameters for Stable Diffusion pipelines (#521)

* Add callback parameters for Stable Diffusion pipelines

Signed-off-by: James R T <jamestiotio@gmail.com>

* Lint code with `black --preview`

Signed-off-by: James R T <jamestiotio@gmail.com>

* Refactor callback implementation for Stable Diffusion pipelines

* Fix missing imports

Signed-off-by: James R T <jamestiotio@gmail.com>

* Fix documentation format

Signed-off-by: James R T <jamestiotio@gmail.com>

* Add kwargs parameter to standardize with other pipelines

Signed-off-by: James R T <jamestiotio@gmail.com>

* Modify Stable Diffusion pipeline callback parameters

Signed-off-by: James R T <jamestiotio@gmail.com>

* Remove useless imports

Signed-off-by: James R T <jamestiotio@gmail.com>

* Change types for timestep and onnx latents

* Fix docstring style

* Return decode_latents and run_safety_checker back into __call__

* Remove unused imports

* Add intermediate state tests for Stable Diffusion pipelines

Signed-off-by: James R T <jamestiotio@gmail.com>

* Fix intermediate state tests for Stable Diffusion pipelines

Signed-off-by: James R T <jamestiotio@gmail.com>

Signed-off-by: James R T <jamestiotio@gmail.com>
This commit is contained in:
James R T 2022-10-03 01:56:36 +08:00 committed by GitHub
parent 5156acc476
commit 2558977bc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 259 additions and 12 deletions

View File

@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import torch
@ -122,6 +122,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
@ -159,6 +161,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@ -178,6 +186,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@ -277,14 +293,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)

View File

@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import numpy as np
import torch
@ -133,6 +133,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@ -170,6 +173,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@ -188,6 +197,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
@ -265,6 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
@ -295,14 +313,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

View File

@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import numpy as np
import torch
@ -149,6 +149,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@ -190,6 +193,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@ -208,6 +217,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
@ -297,7 +314,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
extra_step_kwargs["eta"] = eta
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
@ -331,14 +350,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# scale and decode the image latents with vae
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

View File

@ -1,5 +1,5 @@
import inspect
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import numpy as np
@ -56,6 +56,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
if isinstance(prompt, str):
@ -68,6 +70,14 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@ -151,14 +161,18 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = np.array(latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)

View File

@ -1435,3 +1435,177 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Andromeda galaxy in a bottle"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
prompt=prompt,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 51
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 96)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
prompt=prompt,
init_image=init_image,
strength=0.75,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 38
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A red cat sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 38
@slow
def test_stable_diffusion_onnx_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
test_callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[-0.6254, -0.2742, -1.0710, 0.2296, -1.1683, 0.6913, -2.0605, -0.0682, 0.9700]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)
prompt = "Andromeda galaxy in a bottle"
np.random.seed(0)
pipe(prompt=prompt, num_inference_steps=50, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
assert test_callback_fn.has_been_called
assert number_of_steps == 51