Implements "scheduling" for blending of the original latents and a latent blending formula that preserves details in blend transition areas.

This commit is contained in:
CodeHatchling 2023-11-28 16:10:22 -07:00
parent bbba133f05
commit e715e46b6a
1 changed files with 59 additions and 2 deletions

View File

@ -43,6 +43,9 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.mask = None
self.nmask = None
self.mask_blend_power = 1
self.mask_blend_scale = 1
self.mask_blend_offset = 0
self.init_latent = None
self.steps = None
"""number of steps as specified by user in UI"""
@ -56,6 +59,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
self.mask_before_denoising = False
@property
@ -89,6 +95,55 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
def latent_blend(a, b, t):
"""
Interpolates two latent image representations according to the parameter t,
where the interpolated vectors' magnitudes are also interpolated separately.
The "detail_preservation" factor biases the magnitude interpolation towards
the larger of the two magnitudes.
"""
# Record the original latent vector magnitudes.
# We bring them to a power so that larger magnitudes are favored over smaller ones.
# 64-bit operations are used here to allow large exponents.
detail_preservation = 32
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation
one_minus_t = 1 - t
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation)
# Linearly interpolate the image vectors.
image_interp = a * one_minus_t + b * t
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype)
return image_interp
def get_modified_nmask(nmask, _sigma):
"""
Converts a negative mask representing the transparency of the original latent vectors being overlayed
to a mask that is scaled according to the denoising strength for this step.
Where:
0 = fully opaque, infinite density, fully masked
1 = fully transparent, zero density, fully unmasked
We bring this transparency to a power, as this allows one to simulate N number of blending operations
where N can be any positive real value. Using this one can control the balance of influence between
the denoiser and the original latents according to the sigma value.
NOTE: "mask" is not used
"""
return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset)
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@ -105,8 +160,9 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x
x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@ -207,8 +263,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)