Enables the original functionality to be toggled on and off.

CodeHatchling 2023-12-04 01:57:21 -07:00
parent aaacf48232
commit 259d33c3c8
1 changed file with 66 additions and 25 deletions

modules/processing.py

@@ -88,9 +88,12 @@ def apply_overlay(image, paste_loc, index, overlays):
    return image


-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
    if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L")
+        if round:
+            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        else:
+            image = image.split()[-1].convert("L")
    else:
        image = image.convert('L')
    return image
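
Note (illustrative, not part of the diff): with round=True the alpha channel is thresholded to a hard 0/255 mask, while round=False preserves the smooth alpha ramp that soft inpainting relies on. A minimal sketch, assuming Pillow is installed and the create_binary_mask defined above is in scope:

    from PIL import Image

    # Hypothetical RGBA input whose alpha ramps smoothly from 0 to 255.
    alpha = Image.linear_gradient("L")
    rgba = Image.merge("RGBA", (alpha, alpha, alpha, alpha))

    hard = create_binary_mask(rgba, round=True)   # only 0 or 255 after the .point() threshold
    soft = create_binary_mask(rgba, round=False)  # keeps every intermediate alpha value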
@@ -316,7 +319,7 @@ class StableDiffusionProcessing:
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
@@ -327,6 +330,11 @@ class StableDiffusionProcessing:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+                if round_image_mask:
+                    # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                    conditioning_mask = torch.round(conditioning_mask)
+
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
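
Note (illustrative, not part of the diff): what the rounding step does to in-between mask values, assuming PyTorch:

    import torch

    conditioning_mask = torch.tensor([[[[0.2, 0.5, 0.8]]]])  # shape (1, 1, 1, 3), values in 0..1
    print(torch.round(conditioning_mask))  # tensor([[[[0., 0., 1.]]]]) -- a hard 0/1 mask (0.5 rounds to even, i.e. 0)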
@@ -350,7 +358,7 @@ class StableDiffusionProcessing:

        return image_conditioning

-    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -362,7 +370,10 @@ class StableDiffusionProcessing:
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image,
+                                                      latent_image,
+                                                      image_mask=image_mask,
+                                                      round_image_mask=round_image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)
@@ -878,8 +889,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

            # Generate the mask(s) based on similarity between the original and denoised latent vectors
-            if getattr(p, "image_mask", None) is not None:
+            if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
                # latent_mask = p.nmask[0].float().cpu()

                # convert the original mask into a form we use to scale distances for thresholding
@@ -911,7 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                    converted_mask = converted_mask.astype(np.uint8)
                    converted_mask = Image.fromarray(converted_mask)
                    converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
-                    converted_mask = create_binary_mask(converted_mask)
+                    converted_mask = create_binary_mask(converted_mask, round=False)

                    # Remove aliasing artifacts using a gaussian blur.
                    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
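
Note (illustrative sketch with a hypothetical mask, not the commit's variables): the blur turns a hard edge into a smooth ramp, assuming Pillow:

    from PIL import Image, ImageFilter

    mask = Image.new("L", (64, 64), 0)
    mask.paste(255, (16, 16, 48, 48))                            # hard-edged square mask
    soft_mask = mask.filter(ImageFilter.GaussianBlur(radius=4))  # edges now ramp smoothly
    print(mask.getpixel((16, 32)), soft_mask.getpixel((16, 32)))  # 255 vs. an in-between value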
@@ -1010,23 +1022,33 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

-                if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
-                    image_mask = p.masks_for_overlay[i].convert('RGB')
-                    image_mask_composite = Image.composite(
-                        original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
-                        images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
+                if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+                    if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay:
+                        image_mask = p.masks_for_overlay[i].convert('RGB')
+                        image_mask_composite = Image.composite(
+                            original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+                            images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
+                    elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+                        image_mask = p.mask_for_overlay.convert('RGB')
+                        image_mask_composite = Image.composite(
+                            original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+                            images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                    else:
+                        image_mask = None
+                        image_mask_composite = None

-                    if opts.save_mask:
-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+                    if image_mask is not None and image_mask_composite is not None:
+                        if opts.save_mask:
+                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")

-                    if opts.save_mask_composite:
-                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+                        if opts.save_mask_composite:
+                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")

-                    if opts.return_mask:
-                        output_images.append(image_mask)
+                        if opts.return_mask:
+                            output_images.append(image_mask)

-                    if opts.return_mask_composite:
-                        output_images.append(image_mask_composite)
+                        if opts.return_mask_composite:
+                            output_images.append(image_mask_composite)

            del x_samples_ddim
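
Note (illustrative, hypothetical inputs rather than the commit's variables): the composite above premultiplies alpha via the 'RGBa' mode before compositing against a transparent canvas. The same pattern in isolation, assuming Pillow:

    from PIL import Image

    base = Image.new("RGB", (64, 64), (200, 50, 50))
    mask = Image.new("L", (64, 64), 0)
    mask.paste(255, (0, 0, 32, 64))  # keep only the left half

    composite = Image.composite(
        base.convert('RGBA').convert('RGBa'), Image.new('RGBa', base.size),
        mask).convert('RGBA')
    # Left half keeps the base pixels; right half is fully transparent.
    print(composite.getpixel((8, 8)), composite.getpixel((56, 8)))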
@@ -1439,6 +1461,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    nmask: torch.Tensor = field(default=None, init=False)
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)
+    mask_for_overlay: Image = field(default=None, init=False)
    init_latent: torch.Tensor = field(default=None, init=False)

    def __post_init__(self):
@@ -1471,7 +1494,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if image_mask is not None:
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask)
+            image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None))

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)
@@ -1489,6 +1512,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                image_mask = Image.fromarray(np_mask)

            if self.inpaint_full_res:
+                self.mask_for_overlay = image_mask if self.soft_inpainting is None else None
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1500,7 +1524,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)

+                if self.soft_inpainting is None:
+                    np_mask = np.array(image_mask)
+                    np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+                    self.mask_for_overlay = Image.fromarray(np_mask)
+
-            self.masks_for_overlay = []
+            self.masks_for_overlay = [] if self.soft_inpainting is not None else None
            self.overlay_images = []

            latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
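
Note (illustrative, not part of the diff): the np.clip step doubles mask brightness so half-intensity pixels in the resized mask read as fully selected in the overlay. With NumPy:

    import numpy as np

    np_mask = np.array([0, 64, 128, 200], dtype=np.uint8)
    brightened = np.clip(np_mask.astype(np.float32) * 2, 0, 255).astype(np.uint8)
    print(brightened)  # [  0 128 255 255]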
@@ -1522,8 +1551,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
-                self.overlay_images.append(image)
-                self.masks_for_overlay.append(image_mask)
+                if self.soft_inpainting is not None:
+                    # We apply the masks AFTER to adjust mask based on changed content.
+                    self.overlay_images.append(image)
+                    self.masks_for_overlay.append(image_mask)
+                else:
+                    image_masked = Image.new('RGBa', (image.width, image.height))
+                    image_masked.paste(image.convert("RGBA").convert("RGBa"),
+                                       mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+                    self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
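
Note (standalone sketch with hypothetical inputs, not the commit's variables): the inverted-mask paste in the non-soft branch keeps the unmasked region opaque and leaves the region to be inpainted transparent. Assuming Pillow:

    from PIL import Image, ImageOps

    image = Image.new("RGB", (64, 64), (30, 120, 30))
    mask_for_overlay = Image.new("L", (64, 64), 0)
    mask_for_overlay.paste(255, (16, 16, 48, 48))  # region to be inpainted

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"),
                       mask=ImageOps.invert(mask_for_overlay.convert('L')))
    overlay = image_masked.convert('RGBA')
    print(overlay.getpixel((0, 0)), overlay.getpixel((32, 32)))  # opaque outside, transparent inside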
@@ -1576,6 +1612,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
+            if self.soft_inpainting is None:
+                latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
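
Note (illustrative, NumPy only): without soft inpainting, fractional latent-mask values snap to 0 or 1 before tiling:

    import numpy as np

    latmask = np.array([0.1, 0.5, 0.9], dtype=np.float32)
    print(np.around(latmask))  # [0. 0. 1.] -- np.around rounds 0.5 to the nearest even value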
@@ -1587,7 +1625,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1,
+                                                                  self.init_latent,
+                                                                  image_mask,
+                                                                  self.soft_inpainting is None)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = self.rng.next()