add new sampler DDIM CFG++
This commit is contained in:
parent
a30b19dd55
commit
663a4d80df
|
@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.model_wrap = None
|
||||
self.p = None
|
||||
|
||||
self.last_noise_uncond = None
|
||||
|
||||
# NOTE: masking before denoising can cause the original latents to be oversmoothed
|
||||
# as the original latents do not have noise
|
||||
self.mask_before_denoising = False
|
||||
|
@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
|
|||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
is_cfg_pp = 'CFG++' in self.sampler.config.name
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
|
@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
|
|||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
if is_cfg_pp:
|
||||
self.last_noise_uncond = x_out[-uncond.shape[0]:]
|
||||
self.last_noise_uncond = torch.clone(self.last_noise_uncond)
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
elif is_cfg_pp:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale/12.5) # CFG++ scale of (0, 1) maps to (1.0, 12.5)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import modules.shared as shared
|
|||
|
||||
samplers_timesteps = [
|
||||
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||
('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
|
||||
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
||||
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
||||
]
|
||||
|
|
|
@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
|
|||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||
""" Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
|
||||
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
|
||||
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
|
||||
"""
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones((x.shape[0]))
|
||||
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
|
||||
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||
last_noise_uncond = model.last_noise_uncond
|
||||
|
||||
a_t = alphas[index].item() * s_x
|
||||
a_prev = alphas_prev[index].item() * s_x
|
||||
sigma_t = sigmas[index].item() * s_x
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
|
||||
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
||||
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
|
|
Loading…
Reference in New Issue