Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from the masks while preserving detail in largely unchanged content.

CodeHatchling 2023-12-02 21:07:02 -07:00
parent 609dea36ea
commit 73ab982d1b
1 changed file with 90 additions and 29 deletions
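At a glance, the approach: instead of deriving the blend mask from the user-supplied mask up front, the original and denoised latents are compared per pixel, and the measured change decides how strongly each region blends toward the new image. Below is a minimal sketch of that core idea (a hypothetical standalone helper, not code from this commit; the actual change also applies percentile filtering, smootherstep easing, blurring, and resizing):

    import torch
    import numpy as np
    from PIL import Image

    def estimate_change_mask(latent_orig: torch.Tensor, latent_proc: torch.Tensor) -> Image.Image:
        # Per-pixel Euclidean distance across latent channels: (N, C, H, W) -> (N, H, W).
        distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
        # Map distance to "keep the original" opacity: 1.0 where nothing changed, 0.5 at distance 1.
        keep_original = 1.0 / (1.0 + distance ** 2)
        # Invert so changed regions are white, then quantize to an 8-bit grayscale mask.
        mask = (255.0 * (1.0 - keep_original[0])).cpu().numpy().astype(np.uint8)
        return Image.fromarray(mask, mode='L')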


@@ -9,7 +9,7 @@ from dataclasses import dataclass, field
import torch
import numpy as np
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageFilter
import random
import cv2
from skimage import exposure
@@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image):
    return image.convert('RGB')


def uncrop(image, dest_size, paste_loc):
    x, y, w, h = paste_loc
    base_image = Image.new('RGBA', dest_size)
    image = images.resize_image(1, image, w, h)
    base_image.paste(image, (x, y))
    image = base_image

    return image


def apply_overlay(image, paste_loc, index, overlays):
    if overlays is None or index >= len(overlays):
        return image
@@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays):
    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image
        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
@@ -140,6 +146,7 @@ class StableDiffusionProcessing:
    do_not_save_grid: bool = False
    extra_generation_params: dict[str, Any] = None
    overlay_images: list = None
    masks_for_overlay: list = None
    eta: float = None
    do_not_reload_embeddings: bool = False
    denoising_strength: float = 0
@@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
        if getattr(samples_ddim, 'already_decoded', False):
            x_samples_ddim = samples_ddim
            # todo: generate masks the old fashioned way
        else:
            if opts.sd_vae_decode_method != 'Full':
                p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            # Generate the mask(s) based on similarity between the original and denoised latent vectors
            if getattr(p, "image_mask", None) is not None:
                # latent_mask = p.nmask[0].float().cpu()

                # convert the original mask into a form we use to scale distances for thresholding
                # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
                # mask_scalar = mask_scalar / (1.00001-mask_scalar)
                # mask_scalar = mask_scalar.numpy()

                latent_orig = p.init_latent
                latent_proc = samples_ddim
                latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
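                # Note: with latents shaped (N, C, H, W), the norm over dim=1 collapses the
                # channel axis, yielding one Euclidean distance per latent pixel per image.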
                kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)

                for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
                    converted_mask = distance_map.float().cpu().numpy()
                    converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                                      percentile_min=0.9, percentile_max=1, min_width=1)
                    converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                                      percentile_min=0.25, percentile_max=0.75, min_width=1)
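                    # The two passes appear to act as a soft dilation (averaging the top decile
                    # of each Gaussian-weighted neighborhood) followed by an interquartile
                    # smoothing pass that suppresses speckle in the raw distance map.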
                    # The distance at which opacity of original decreases to 50%
                    # half_weighted_distance = 1 # * mask_scalar
                    # converted_mask = converted_mask / half_weighted_distance

                    converted_mask = 1 / (1 + converted_mask ** 2)
                    converted_mask = images.smootherstep(converted_mask)
                    converted_mask = 1 - converted_mask
                    converted_mask = 255. * converted_mask
                    converted_mask = converted_mask.astype(np.uint8)
                    converted_mask = Image.fromarray(converted_mask)
                    converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
                    converted_mask = create_binary_mask(converted_mask)
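                    # 1/(1+d^2) maps distance to "keep the original" opacity: 1.0 where the
                    # latents match, 0.5 at distance 1. smootherstep eases the transition, and
                    # the inversion makes heavily changed regions white in the final mask.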
                    # Remove aliasing artifacts using a gaussian blur.
                    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

                    # Expand the mask to fit the whole image if needed.
                    if p.paste_to is not None:
                        converted_mask = uncrop(converted_mask,
                                                (overlay_image.width, overlay_image.height),
                                                p.paste_to)

                    p.masks_for_overlay[i] = converted_mask

                    image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
                    image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                                       mask=ImageOps.invert(converted_mask.convert('L')))

                    p.overlay_images[i] = image_masked.convert('RGBA')
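                    # 'RGBa' is PIL's premultiplied-alpha mode; pasting through the inverted
                    # mask leaves the overlay transparent exactly where changed content should
                    # show through from the newly generated image.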
            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
                                                 target_device=devices.cpu,
                                                 check_for_nans=True)

        x_samples_ddim = torch.stack(x_samples_ddim).float()
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
        x_samples_ddim = batch_params.images

    def infotext(index=0, use_main_prompt=False):
        return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
        return create_infotext(p, p.prompts, p.seeds, p.subseeds,
                               use_main_prompt=use_main_prompt, index=index,
                               all_negative_prompts=p.negative_prompts)

    save_samples = p.save_samples()
@@ -923,19 +987,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                image = apply_color_correction(p.color_corrections[i], image)

            # If the intention is to show the output from the model
            # that is being composited over the original image,
            # we need to keep the original image around
            # and use it in the composite step.
            original_denoised_image = image.copy()
            image = apply_overlay(image, p.paste_to, i, p.overlay_images)

            if save_samples:
                images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
                images.save_image(image, p.outpath_samples, "", p.seeds[i],
                                  p.prompts[i], opts.samples_format, info=infotext(i), p=p)

            text = infotext(i)
            infotexts.append(text)
            if opts.enable_pnginfo:
                image.info["parameters"] = text
            output_images.append(image)

            if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                image_mask = p.mask_for_overlay.convert('RGB')
                image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
            if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                image_mask = p.masks_for_overlay[i].convert('RGB')
                image_mask_composite = Image.composite(
                    original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
                    images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
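                # Note that the composite is now built from original_denoised_image (the raw
                # model output kept above) rather than the overlaid image, so the mask-composite
                # save shows pure model output inside the masked region.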
                if opts.save_mask:
                    images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
@@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    nmask: torch.Tensor = field(default=None, init=False)
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)
    mask_for_overlay: Image = field(default=None, init=False)
    init_latent: torch.Tensor = field(default=None, init=False)

    def __post_init__(self):
@@ -1415,12 +1486,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                image_mask = Image.fromarray(np_mask)

            if self.inpaint_full_res:
                np_mask = np.array(image_mask).astype(np.float32)
                np_mask /= 255
                np_mask = 1-pow(1-np_mask, 100)
                np_mask *= 255
                np_mask = np.clip(np_mask, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1431,13 +1496,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask).astype(np.float32)
                np_mask /= 255
                np_mask = 1-pow(1-np_mask, 100)
                np_mask *= 255
                np_mask = np.clip(np_mask, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.masks_for_overlay = []
            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1459,10 +1519,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
                self.overlay_images.append(image_masked.convert('RGBA'))
                self.overlay_images.append(image)
                self.masks_for_overlay.append(image_mask)
            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:

@@ -1486,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if self.overlay_images is not None:
            self.overlay_images = self.overlay_images * self.batch_size

        if self.masks_for_overlay is not None:
            self.masks_for_overlay = self.masks_for_overlay * self.batch_size

        if self.color_corrections is not None and len(self.color_corrections) == 1:
            self.color_corrections = self.color_corrections * self.batch_size