[StableDiffusionInpaintPipeline] accept tensors for init and mask image (#439)
* accept tensors * fix mask handling * make device placement cleaner * update doc for mask image
This commit is contained in:
parent
761f0297b0
commit
06924c6a4f
|
@ -145,8 +145,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
process. This is the image whose masked region will be inpainted.
|
||||
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
|
||||
replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
|
||||
converted to a single channel (luminance) before use.
|
||||
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
||||
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
||||
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
||||
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
||||
|
@ -202,10 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
# preprocess image
|
||||
init_image = preprocess_image(init_image).to(self.device)
|
||||
if not isinstance(init_image, torch.FloatTensor):
|
||||
init_image = preprocess_image(init_image)
|
||||
init_image.to(self.device)
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
@ -215,8 +218,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
init_latents_orig = init_latents
|
||||
|
||||
# preprocess mask
|
||||
mask = preprocess_mask(mask_image).to(self.device)
|
||||
mask = torch.cat([mask] * batch_size)
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = preprocess_mask(mask_image)
|
||||
mask_image.to(self.device)
|
||||
mask = torch.cat([mask_image] * batch_size)
|
||||
|
||||
# check sizes
|
||||
if not mask.shape == init_latents.shape:
|
||||
|
|
Loading…
Reference in New Issue