Rewrote latent_blend() to use in-place operations and to aggressively "del" references with the intention of minimizing allocations and easing garbage collection.
This commit is contained in:
parent
73ab982d1b
commit
bb04d400c9
|
@ -102,29 +102,44 @@ class CFGDenoiser(torch.nn.Module):
|
|||
The "detail_preservation" factor biases the magnitude interpolation towards
|
||||
the larger of the two magnitudes.
|
||||
"""
|
||||
# Record the original latent vector magnitudes.
|
||||
# We bring them to a power so that larger magnitudes are favored over smaller ones.
|
||||
# 64-bit operations are used here to allow large exponents.
|
||||
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation
|
||||
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation
|
||||
# NOTE: We use inplace operations wherever possible.
|
||||
|
||||
one_minus_t = 1 - t
|
||||
|
||||
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
|
||||
interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation)
|
||||
|
||||
# Linearly interpolate the image vectors.
|
||||
image_interp = a * one_minus_t + b * t
|
||||
a_scaled = a * one_minus_t
|
||||
b_scaled = b * t
|
||||
image_interp = a_scaled
|
||||
image_interp.add_(b_scaled)
|
||||
result_type = image_interp.dtype
|
||||
del a_scaled, b_scaled
|
||||
|
||||
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
|
||||
# 64-bit operations are used here to allow large exponents.
|
||||
image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001
|
||||
current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
|
||||
|
||||
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
|
||||
a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t
|
||||
b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t
|
||||
desired_magnitude = a_magnitude
|
||||
desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation)
|
||||
del a_magnitude, b_magnitude, one_minus_t
|
||||
|
||||
# Change the linearly interpolated image vectors' magnitudes to the value we want.
|
||||
# This is the last 64-bit operation.
|
||||
image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype)
|
||||
image_interp_scaling_factor = desired_magnitude
|
||||
image_interp_scaling_factor.div_(current_magnitude)
|
||||
image_interp_scaled = image_interp
|
||||
image_interp_scaled.mul_(image_interp_scaling_factor)
|
||||
del current_magnitude
|
||||
del desired_magnitude
|
||||
del image_interp
|
||||
del image_interp_scaling_factor
|
||||
|
||||
return image_interp
|
||||
image_interp_scaled = image_interp_scaled.to(result_type)
|
||||
del result_type
|
||||
|
||||
return image_interp_scaled
|
||||
|
||||
def get_modified_nmask(nmask, _sigma):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue