add new sampler DDIM CFG++

This commit is contained in:
v0xie 2024-06-16 17:47:21 -07:00
parent a30b19dd55
commit 663a4d80df
3 changed files with 48 additions and 0 deletions

View File

@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.p = None
self.last_noise_uncond = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
self.mask_before_denoising = False
@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
is_cfg_pp = 'CFG++' in self.sampler.config.name
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
if is_cfg_pp:
self.last_noise_uncond = x_out[-uncond.shape[0]:]
self.last_noise_uncond = torch.clone(self.last_noise_uncond)
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
elif is_cfg_pp:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale/12.5) # CFG++ scale of (0, 1) maps to (1.0, 12.5)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

View File

@ -10,6 +10,7 @@ import modules.shared as shared
samplers_timesteps = [
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
]

View File

@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
return x
@torch.no_grad()
def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
""" Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
"""
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones((x.shape[0]))
s_x = x.new_ones((x.shape[0], 1, 1, 1))
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
last_noise_uncond = model.last_noise_uncond
a_t = alphas[index].item() * s_x
a_prev = alphas_prev[index].item() * s_x
sigma_t = sigmas[index].item() * s_x
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
if callback is not None:
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
return x
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod