From 06924c6a4fa3001973048a3050c37625b86ad066 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 16 Sep 2022 17:35:41 +0200 Subject: [PATCH] [StableDiffusionInpaintPipeline] accept tensors for init and mask image (#439) * accept tensors * fix mask handling * make device placement cleaner * update doc for mask image --- .../pipeline_stable_diffusion_inpaint.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index e97e5207..6076b854 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -145,8 +145,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): process. This is the image whose masked region will be inpainted. mask_image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be - converted to a single channel (luminance) before use. + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` is 1, the denoising process will be run on the masked area for the full number of iterations specified @@ -202,10 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) # preprocess image - init_image = preprocess_image(init_image).to(self.device) + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess_image(init_image) + init_image.to(self.device) # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents @@ -215,8 +218,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): init_latents_orig = init_latents # preprocess mask - mask = preprocess_mask(mask_image).to(self.device) - mask = torch.cat([mask] * batch_size) + if not isinstance(mask_image, torch.FloatTensor): + mask_image = preprocess_mask(mask_image) + mask_image.to(self.device) + mask = torch.cat([mask_image] * batch_size) # check sizes if not mask.shape == init_latents.shape: