Add callback parameters for Stable Diffusion pipelines (#521)
* Add callback parameters for Stable Diffusion pipelines Signed-off-by: James R T <jamestiotio@gmail.com> * Lint code with `black --preview` Signed-off-by: James R T <jamestiotio@gmail.com> * Refactor callback implementation for Stable Diffusion pipelines * Fix missing imports Signed-off-by: James R T <jamestiotio@gmail.com> * Fix documentation format Signed-off-by: James R T <jamestiotio@gmail.com> * Add kwargs parameter to standardize with other pipelines Signed-off-by: James R T <jamestiotio@gmail.com> * Modify Stable Diffusion pipeline callback parameters Signed-off-by: James R T <jamestiotio@gmail.com> * Remove useless imports Signed-off-by: James R T <jamestiotio@gmail.com> * Change types for timestep and onnx latents * Fix docstring style * Return decode_latents and run_safety_checker back into __call__ * Remove unused imports * Add intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T <jamestiotio@gmail.com> * Fix intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T <jamestiotio@gmail.com> Signed-off-by: James R T <jamestiotio@gmail.com>
This commit is contained in:
parent
5156acc476
commit
2558977bc7
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -122,6 +122,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
@ -159,6 +161,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
|
@ -178,6 +186,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
|
@ -277,14 +293,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -133,6 +133,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
@ -170,6 +173,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
|
@ -188,6 +197,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
|
@ -265,6 +282,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimzed to move all timesteps to correct device beforehand
|
||||
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
|
||||
|
@ -295,14 +313,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -149,6 +149,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
@ -190,6 +193,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
|
@ -208,6 +217,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
|
@ -297,7 +314,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimzed to move all timesteps to correct device beforehand
|
||||
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
|
||||
|
@ -331,14 +350,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -56,6 +56,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
|||
latents: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(prompt, str):
|
||||
|
@ -68,6 +70,14 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
|||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
|
@ -151,14 +161,18 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
|||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = np.array(latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae_decoder(latent_sample=latents)[0]
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
|
||||
# run safety checker
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
|
||||
|
||||
|
|
|
@ -1435,3 +1435,177 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_text2img_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
callback=test_callback_fn,
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_img2img_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 96)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
pipe(
|
||||
prompt=prompt,
|
||||
init_image=init_image,
|
||||
strength=0.75,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
callback=test_callback_fn,
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 38
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_inpaint_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "A red cat sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
pipe(
|
||||
prompt=prompt,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.75,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
callback=test_callback_fn,
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 38
|
||||
|
||||
@slow
|
||||
def test_stable_diffusion_onnx_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.6254, -0.2742, -1.0710, 0.2296, -1.1683, 0.6913, -2.0605, -0.0682, 0.9700]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionOnnxPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
np.random.seed(0)
|
||||
pipe(prompt=prompt, num_inference_steps=50, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
|
|
Loading…
Reference in New Issue