add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them)
This commit is contained in:
parent
2d8e4a6544
commit
8285a149d8
|
@ -1,4 +1,4 @@
|
|||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
|
@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
|
|||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_compvis.samplers_data_compvis,
|
||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, sampler):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
|
@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.sampler = sampler
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
|
@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
return denoised
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
return x_out
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
if self.mask is not None:
|
||||
x = self.init_latent * self.mask + self.nmask * x
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
|
@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
|
@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module):
|
|||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
preview = self.sampler.last_latent
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||
else:
|
||||
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
sd_samplers_common.store_latent(preview)
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
|
@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.step += 1
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from collections import namedtuple
|
||||
import inspect
|
||||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
|
||||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
|
@ -127,3 +129,139 @@ def replace_torchsde_browinan():
|
|||
|
||||
|
||||
replace_torchsde_browinan()
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, funcname):
|
||||
self.funcname = funcname
|
||||
self.func = funcname
|
||||
self.extra_params = []
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.eta_option_field = 'eta_ancestral'
|
||||
self.eta_infotext_field = 'Eta'
|
||||
|
||||
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||
|
||||
self.model_wrap = None
|
||||
self.model_wrap_cfg = None
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p) -> dict:
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -4,8 +4,7 @@ import inspect
|
|||
import k_diffusion.sampling
|
||||
from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules.shared import opts, state
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
|
||||
samplers_k_diffusion = [
|
||||
|
@ -54,133 +53,17 @@ k_diffusion_scheduler = {
|
|||
}
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model):
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
super().__init__(funcname)
|
||||
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
|
||||
# NOTE: These are also defined in the StableDiffusionProcessing class.
|
||||
# They should have been here to begin with but we're going to
|
||||
# leave that class __init__ signature alone.
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
sd_samplers_common.store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p: StableDiffusionProcessing):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params["Eta"] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
|
@ -232,22 +115,12 @@ class KDiffusionSampler:
|
|||
|
||||
return sigmas
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
|
@ -296,12 +169,14 @@ class KDiffusionSampler:
|
|||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
|
||||
if 'sigmas' in parameters:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
if self.config.options.get('brownian_noise', False):
|
||||
|
@ -322,3 +197,4 @@ class KDiffusionSampler:
|
|||
|
||||
return samples
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
import torch
|
||||
import inspect
|
||||
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
|
||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
|
||||
samplers_timesteps = [
|
||||
('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}),
|
||||
('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}),
|
||||
('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}),
|
||||
]
|
||||
|
||||
|
||||
samplers_data_timesteps = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_timesteps
|
||||
]
|
||||
|
||||
|
||||
class CompVisTimestepsDenoiser(torch.nn.Module):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
return self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
|
||||
|
||||
class CompVisTimestepsVDenoiser(torch.nn.Module):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inner_model = model
|
||||
|
||||
def predict_eps_from_z_and_v(self, x_t, t, v):
|
||||
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
|
||||
return e_t
|
||||
|
||||
|
||||
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||
|
||||
def __init__(self, model, sampler):
|
||||
super().__init__(model, sampler)
|
||||
|
||||
self.alphas = model.inner_model.alphas_cumprod
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
ts = int(sigma.item())
|
||||
|
||||
s_in = x_in.new_ones([x_in.shape[0]])
|
||||
a_t = self.alphas[ts].item() * s_in
|
||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||
|
||||
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||
|
||||
return pred_x0
|
||||
|
||||
|
||||
class CompVisSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model):
|
||||
super().__init__(funcname)
|
||||
|
||||
self.eta_option_field = 'eta_ddim'
|
||||
self.eta_infotext_field = 'Eta DDIM'
|
||||
|
||||
denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||
self.model_wrap = denoiser(sd_model)
|
||||
self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
|
||||
|
||||
def get_timesteps(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
|
||||
|
||||
return timesteps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
timesteps_sched = timesteps[:t_enc]
|
||||
|
||||
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
|
||||
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
|
||||
|
||||
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'timesteps' in parameters:
|
||||
extra_params_kwargs['timesteps'] = timesteps_sched
|
||||
if 'is_img2img' in parameters:
|
||||
extra_params_kwargs['is_img2img'] = True
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps = steps or p.steps
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'timesteps' in parameters:
|
||||
extra_params_kwargs['timesteps'] = timesteps
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
import torch
|
||||
import tqdm
|
||||
import k_diffusion.sampling
|
||||
import numpy as np
|
||||
|
||||
from modules import shared
|
||||
from modules.models.diffusion.uni_pc import uni_pc
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
|
||||
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||
|
||||
a_t = alphas[index].item() * s_in
|
||||
a_prev = alphas_prev[index].item() * s_in
|
||||
sigma_t = sigmas[index].item() * s_in
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
||||
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
||||
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
old_eps = []
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = alphas[index].item() * s_in
|
||||
a_prev = alphas_prev[index].item() * s_in
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev).sqrt() * e_t
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||
return x_prev, pred_x0
|
||||
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
ts = timesteps[index].item() * s_in
|
||||
t_next = timesteps[max(index - 1, 0)].item() * s_in
|
||||
|
||||
e_t = model(x, ts, **extra_args)
|
||||
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = model(x_prev, t_next, **extra_args)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
else:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
|
||||
x = x_prev
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UniPCCFG(uni_pc.UniPC):
|
||||
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
|
||||
super().__init__(None, *args, **kwargs)
|
||||
|
||||
def after_update(x, model_x):
|
||||
callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
|
||||
self.index += 1
|
||||
|
||||
self.cfg_model = cfg_model
|
||||
self.extra_args = extra_args
|
||||
self.callback = callback
|
||||
self.index = 0
|
||||
self.after_update = after_update
|
||||
|
||||
def get_model_input_time(self, t_continuous):
|
||||
return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
|
||||
|
||||
def model(self, x, t):
|
||||
t_input = self.get_model_input_time(t)
|
||||
|
||||
res = self.cfg_model(x, t_input, **self.extra_args)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
|
||||
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
||||
unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
|
||||
x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||
|
||||
return x
|
Loading…
Reference in New Issue