[StableDiffusionInpaintPipeline] accept tensors for init and mask image (#439)

* accept tensors

* fix mask handling

* make device placement cleaner

* update doc for mask image
This commit is contained in:
Suraj Patil 2022-09-16 17:35:41 +02:00 committed by GitHub
parent 761f0297b0
commit 06924c6a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 6 deletions

View File

@ -145,8 +145,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
process. This is the image whose masked region will be inpainted.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
converted to a single channel (luminance) before use.
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
@ -202,10 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
# preprocess image
init_image = preprocess_image(init_image).to(self.device)
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
init_image.to(self.device)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
@ -215,8 +218,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_latents_orig = init_latents
# preprocess mask
mask = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask] * batch_size)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)
# check sizes
if not mask.shape == init_latents.shape: