Stable diffusion inpainting. (#904)

* begin pipe

* add new pipeline

* add tests

* correct fast test

* up

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

* Update tests/test_pipelines.py

* up

* up

* make style

* add fp16 test

* doc, comments

* up

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Suraj Patil 2022-10-19 16:11:50 +02:00 committed by GitHub
parent 83b696e6c0
commit b35d88c536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 332 additions and 110 deletions

View File

@ -5,7 +5,6 @@ import numpy as np
import torch
import PIL
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@ -17,30 +16,24 @@ from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
return mask
masked_image = image * (mask < 0.5)
return mask, masked_image
class StableDiffusionInpaintPipeline(DiffusionPipeline):
@ -82,6 +75,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@ -140,22 +134,24 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@ -168,22 +164,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
mask_image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@ -201,6 +196,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@ -221,7 +220,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# TODO(Suraj) - adapt to your use case
if isinstance(prompt, str):
batch_size = 1
@ -230,8 +228,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@ -241,9 +239,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@ -262,8 +257,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -300,50 +297,78 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
# encode the init image into latents and scale the latents
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
num_channels_latents = self.vae.config.latent_channels
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# prepare mask and masked_image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# encode the mask image into latents space so we can concatenate it to the latents
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
masked_image_latents = 0.18215 * masked_image_latents
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@ -354,17 +379,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
for i, t in tqdm(enumerate(timesteps)):
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@ -377,10 +398,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
@ -390,13 +407,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None

View File

@ -46,6 +46,7 @@ from diffusers import (
ScoreSdeVeScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
UNet2DConditionModel,
UNet2DModel,
@ -189,6 +190,21 @@ class PipelineFastTests(unittest.TestCase):
)
return model
@property
def dummy_cond_unet_inpaint(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=9,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_vq_model(self):
torch.manual_seed(0)
@ -897,7 +913,7 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint(self):
def test_stable_diffusion_inpaint_legacy(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
@ -910,7 +926,7 @@ class PipelineFastTests(unittest.TestCase):
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
sd_pipe = StableDiffusionInpaintPipelineLegacy(
unet=unet,
scheduler=scheduler,
vae=vae,
@ -956,7 +972,66 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_negative_prompt(self):
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet_inpaint
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
image=init_image,
mask_image=mask_image,
)
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
image=init_image,
mask_image=mask_image,
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5075, 0.4485, 0.4558, 0.5369, 0.5369, 0.5236, 0.5127, 0.4983, 0.4776])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
@ -969,7 +1044,7 @@ class PipelineFastTests(unittest.TestCase):
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
sd_pipe = StableDiffusionInpaintPipelineLegacy(
unet=unet,
scheduler=scheduler,
vae=vae,
@ -1122,7 +1197,7 @@ class PipelineFastTests(unittest.TestCase):
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
def test_stable_diffusion_inpaint_num_images_per_prompt(self):
def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
device = "cpu"
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
@ -1135,7 +1210,7 @@ class PipelineFastTests(unittest.TestCase):
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
sd_pipe = StableDiffusionInpaintPipelineLegacy(
unet=unet,
scheduler=scheduler,
vae=vae,
@ -1274,15 +1349,15 @@ class PipelineFastTests(unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_inpaint_fp16(self):
"""Test that stable diffusion inpaint works with fp16"""
unet = self.dummy_cond_unet
"""Test that stable diffusion inpaint_legacy works with fp16"""
unet = self.dummy_cond_unet_inpaint
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# put models in fp16
@ -1297,8 +1372,8 @@ class PipelineFastTests(unittest.TestCase):
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
safety_checker=None,
feature_extractor=None,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
@ -1310,11 +1385,11 @@ class PipelineFastTests(unittest.TestCase):
generator=generator,
num_inference_steps=2,
output_type="np",
init_image=init_image,
image=init_image,
mask_image=mask_image,
).images
assert image.shape == (1, 32, 32, 3)
assert image.shape == (1, 128, 128, 3)
class PipelineTesterMixin(unittest.TestCase):
@ -1924,6 +1999,90 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_pipeline(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/yellow_cat_sitting_on_a_park_bench.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "fusing/sd-inpaint-temp"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_pipeline_fp16(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/yellow_cat_sitting_on_a_park_bench.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "fusing/sd-inpaint-temp"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
revision="fp16",
torch_dtype=torch.float16,
safety_checker=None,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_legacy_pipeline(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
@ -1966,7 +2125,49 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_pipeline_k_lms(self):
def test_stable_diffusion_inpaint_pipeline_pndm(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
model_id = "fusing/sd-inpaint-temp"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, safety_checker=self.dummy_safety_checker, scheduler=pndm
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self):
# TODO(Anton, Patrick) - I think we can remove this test soon
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
@ -2199,7 +2400,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_intermediate_state(self):
def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: