Cleaned up code, moved main code contributions into soft_inpainting.py

This commit is contained in:
CodeHatchling 2023-12-04 16:06:58 -07:00
parent 259d33c3c8
commit 976c1053ef
4 changed files with 173 additions and 149 deletions

View File

@ -892,55 +892,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
# Generate the mask(s) based on similarity between the original and denoised latent vectors
if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
# latent_mask = p.nmask[0].float().cpu()
# convert the original mask into a form we use to scale distances for thresholding
# mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
# mask_scalar = mask_scalar / (1.00001-mask_scalar)
# mask_scalar = mask_scalar.numpy()
latent_orig = p.init_latent
latent_proc = samples_ddim
latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
converted_mask = distance_map.float().cpu().numpy()
converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.9, percentile_max=1, min_width=1)
converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.25, percentile_max=0.75, min_width=1)
# The distance at which opacity of original decreases to 50%
# half_weighted_distance = 1 # * mask_scalar
# converted_mask = converted_mask / half_weighted_distance
converted_mask = 1 / (1 + converted_mask ** 2)
converted_mask = images.smootherstep(converted_mask)
converted_mask = 1 - converted_mask
converted_mask = 255. * converted_mask
converted_mask = converted_mask.astype(np.uint8)
converted_mask = Image.fromarray(converted_mask)
converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
converted_mask = create_binary_mask(converted_mask, round=False)
# Remove aliasing artifacts using a gaussian blur.
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
# Expand the mask to fit the whole image if needed.
if p.paste_to is not None:
converted_mask = uncrop(converted_mask,
(overlay_image.width, overlay_image.height),
p.paste_to)
p.masks_for_overlay[i] = converted_mask
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
mask=ImageOps.invert(converted_mask.convert('L')))
p.overlay_images[i] = image_masked.convert('RGBA')
si.generate_adaptive_masks(latent_orig=p.init_latent,
latent_processed=samples_ddim,
overlay_images=p.overlay_images,
masks_for_overlay=p.masks_for_overlay,
width=p.width,
height=p.height,
paste_to=p.paste_to)
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
target_device=devices.cpu,

View File

@ -94,76 +94,6 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
def latent_blend(a, b, t, one_minus_t=None):
"""
Interpolates two latent image representations according to the parameter t,
where the interpolated vectors' magnitudes are also interpolated separately.
The "detail_preservation" factor biases the magnitude interpolation towards
the larger of the two magnitudes.
"""
# NOTE: We use inplace operations wherever possible.
if one_minus_t is None:
one_minus_t = 1 - t
if self.soft_inpainting is None:
return a * one_minus_t + b * t
# Linearly interpolate the image vectors.
a_scaled = a * one_minus_t
b_scaled = b * t
image_interp = a_scaled
image_interp.add_(b_scaled)
result_type = image_interp.dtype
del a_scaled, b_scaled
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t
desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation)
del a_magnitude, b_magnitude, one_minus_t
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
image_interp_scaling_factor = desired_magnitude
image_interp_scaling_factor.div_(current_magnitude)
image_interp_scaled = image_interp
image_interp_scaled.mul_(image_interp_scaling_factor)
del current_magnitude
del desired_magnitude
del image_interp
del image_interp_scaling_factor
image_interp_scaled = image_interp_scaled.to(result_type)
del result_type
return image_interp_scaled
def get_modified_nmask(nmask, _sigma):
"""
Converts a negative mask representing the transparency of the original latent vectors being overlayed
to a mask that is scaled according to the denoising strength for this step.
Where:
0 = fully opaque, infinite density, fully masked
1 = fully transparent, zero density, fully unmasked
We bring this transparency to a power, as this allows one to simulate N number of blending operations
where N can be any positive real value. Using this one can control the balance of influence between
the denoiser and the original latents according to the sigma value.
NOTE: "mask" is not used
"""
if self.soft_inpainting is None:
return nmask
return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale)
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@ -184,9 +114,12 @@ class CFGDenoiser(torch.nn.Module):
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
if self.soft_inpainting is None:
x = latent_blend(self.init_latent, x, self.nmask, self.mask)
x = self.init_latent * self.mask + self.nmask * x
else:
x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
x = si.latent_blend(self.soft_inpainting,
self.init_latent,
x,
si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@ -290,9 +223,12 @@ class CFGDenoiser(torch.nn.Module):
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
if self.soft_inpainting is None:
denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask)
denoised = self.init_latent * self.mask + self.nmask * denoised
else:
denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
denoised = si.latent_blend(self.soft_inpainting,
self.init_latent,
denoised,
si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

View File

@ -4,13 +4,6 @@ class SoftInpaintingSettings:
self.mask_blend_scale = mask_blend_scale
self.inpaint_detail_preservation = inpaint_detail_preservation
def get_paste_fields(self):
return [
(self.mask_blend_power, gen_param_labels.mask_blend_power),
(self.mask_blend_scale, gen_param_labels.mask_blend_scale),
(self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation),
]
def add_generation_params(self, dest):
dest[enabled_gen_param_label] = True
dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
@ -18,25 +11,169 @@ class SoftInpaintingSettings:
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
# ------------------- Methods -------------------
def latent_blend(soft_inpainting, a, b, t):
"""
Interpolates two latent image representations according to the parameter t,
where the interpolated vectors' magnitudes are also interpolated separately.
The "detail_preservation" factor biases the magnitude interpolation towards
the larger of the two magnitudes.
"""
import torch
# NOTE: We use inplace operations wherever possible.
one_minus_t = 1 - t
# Linearly interpolate the image vectors.
a_scaled = a * one_minus_t
b_scaled = b * t
image_interp = a_scaled
image_interp.add_(b_scaled)
result_type = image_interp.dtype
del a_scaled, b_scaled
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
del a_magnitude, b_magnitude, one_minus_t
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
image_interp_scaling_factor = desired_magnitude
image_interp_scaling_factor.div_(current_magnitude)
image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
image_interp_scaled = image_interp
image_interp_scaled.mul_(image_interp_scaling_factor)
del current_magnitude
del desired_magnitude
del image_interp
del image_interp_scaling_factor
del result_type
return image_interp_scaled
def get_modified_nmask(soft_inpainting, nmask, sigma):
"""
Converts a negative mask representing the transparency of the original latent vectors being overlayed
to a mask that is scaled according to the denoising strength for this step.
Where:
0 = fully opaque, infinite density, fully masked
1 = fully transparent, zero density, fully unmasked
We bring this transparency to a power, as this allows one to simulate N number of blending operations
where N can be any positive real value. Using this one can control the balance of influence between
the denoiser and the original latents according to the sigma value.
NOTE: "mask" is not used
"""
import torch
return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
def generate_adaptive_masks(
latent_orig,
latent_processed,
overlay_images,
masks_for_overlay,
width, height,
paste_to):
import torch
import numpy as np
import modules.processing as proc
import modules.images as images
from PIL import Image, ImageOps, ImageFilter
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
# latent_mask = p.nmask[0].float().cpu()
# convert the original mask into a form we use to scale distances for thresholding
# mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
# mask_scalar = mask_scalar / (1.00001-mask_scalar)
# mask_scalar = mask_scalar.numpy()
latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
converted_mask = distance_map.float().cpu().numpy()
converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.9, percentile_max=1, min_width=1)
converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.25, percentile_max=0.75, min_width=1)
# The distance at which opacity of original decreases to 50%
# half_weighted_distance = 1 # * mask_scalar
# converted_mask = converted_mask / half_weighted_distance
converted_mask = 1 / (1 + converted_mask ** 2)
converted_mask = images.smootherstep(converted_mask)
converted_mask = 1 - converted_mask
converted_mask = 255. * converted_mask
converted_mask = converted_mask.astype(np.uint8)
converted_mask = Image.fromarray(converted_mask)
converted_mask = images.resize_image(2, converted_mask, width, height)
converted_mask = proc.create_binary_mask(converted_mask, round=False)
# Remove aliasing artifacts using a gaussian blur.
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
# Expand the mask to fit the whole image if needed.
if paste_to is not None:
converted_mask = proc. uncrop(converted_mask,
(overlay_image.width, overlay_image.height),
paste_to)
masks_for_overlay[i] = converted_mask
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
mask=ImageOps.invert(converted_mask.convert('L')))
overlay_images[i] = image_masked.convert('RGBA')
# ------------------- Constants -------------------
default = SoftInpaintingSettings(1, 0.5, 4)
enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
enabled_el_id = "soft_inpainting_enabled"
default = SoftInpaintingSettings(1, 0.5, 4)
ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost")
ui_labels = SoftInpaintingSettings(
"Schedule bias",
"Preservation strength",
"Transition contrast boost")
ui_info = SoftInpaintingSettings(
mask_blend_power="Shifts when preservation of original content occurs during denoising.",
# "Below 1: Stronger preservation near the end (with low sigma)\n"
# "1: Balanced (proportional to sigma)\n"
# "Above 1: Stronger preservation in the beginning (with high sigma)",
mask_blend_scale="How strongly partially masked content should be preserved.",
# "Low values: Favors generated content.\n"
# "High values: Favors original content.",
inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.")
"Shifts when preservation of original content occurs during denoising.",
"How strongly partially masked content should be preserved.",
"Amplifies the contrast that may be lost in partially masked regions.")
gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost")
el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation")
gen_param_labels = SoftInpaintingSettings(
"Soft inpainting schedule bias",
"Soft inpainting preservation strength",
"Soft inpainting transition contrast boost")
el_ids = SoftInpaintingSettings(
"mask_blend_power",
"mask_blend_scale",
"inpaint_detail_preservation")
# ------------------- UI -------------------
def gradio_ui():

View File

@ -683,13 +683,6 @@ def create_ui():
with FormRow():
soft_inpainting = si.gradio_ui()
"""
mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power")
mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale")
inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset")
"""
with FormRow():
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")