Rewrote latent_blend() to use in-place operations and to aggressively "del" references with the intention of minimizing allocations and easing garbage collection.

This commit is contained in:
CodeHatchling 2023-12-02 21:08:26 -07:00
parent 73ab982d1b
commit bb04d400c9
1 changed files with 27 additions and 12 deletions

View File

@ -102,29 +102,44 @@ class CFGDenoiser(torch.nn.Module):
The "detail_preservation" factor biases the magnitude interpolation towards
the larger of the two magnitudes.
"""
# Record the original latent vector magnitudes.
# We bring them to a power so that larger magnitudes are favored over smaller ones.
# 64-bit operations are used here to allow large exponents.
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation
# NOTE: We use inplace operations wherever possible.
one_minus_t = 1 - t
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation)
# Linearly interpolate the image vectors.
image_interp = a * one_minus_t + b * t
a_scaled = a * one_minus_t
b_scaled = b * t
image_interp = a_scaled
image_interp.add_(b_scaled)
result_type = image_interp.dtype
del a_scaled, b_scaled
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001
current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t
desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation)
del a_magnitude, b_magnitude, one_minus_t
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype)
image_interp_scaling_factor = desired_magnitude
image_interp_scaling_factor.div_(current_magnitude)
image_interp_scaled = image_interp
image_interp_scaled.mul_(image_interp_scaling_factor)
del current_magnitude
del desired_magnitude
del image_interp
del image_interp_scaling_factor
return image_interp
image_interp_scaled = image_interp_scaled.to(result_type)
del result_type
return image_interp_scaled
def get_modified_nmask(nmask, _sigma):
"""