diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 113425b2a..bc9b97e45 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -56,6 +56,7 @@ class CFGDenoiser(torch.nn.Module): self.sampler = sampler self.model_wrap = None self.p = None + self.mask_before_denoising = False @property def inner_model(self): @@ -104,7 +105,7 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" - if self.mask is not None: + if self.mask_before_denoising and self.mask is not None: x = self.init_latent * self.mask + self.nmask * x batch_size = len(conds_list) @@ -206,6 +207,9 @@ class CFGDenoiser(torch.nn.Module): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + if not self.mask_before_denoising and self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) if opts.live_preview_content == "Prompt": diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 6aed29742..c1f534edf 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -49,6 +49,7 @@ class CFGDenoiserTimesteps(CFGDenoiser): super().__init__(sampler) self.alphas = shared.sd_model.alphas_cumprod + self.mask_before_denoising = True def get_pred_x0(self, x_in, x_out, sigma): ts = sigma.to(dtype=int)