separate part of denoiser code into a function to make it easier for extensions to override it

This commit is contained in:
AUTOMATIC 2022-12-24 18:38:16 +03:00
parent 03d7b39453
commit 0b8acce6a9
1 changed files with 11 additions and 6 deletions

View File

@ -288,6 +288,16 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
return denoised
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@ -329,12 +339,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised