From 8285a149d8c488ae6c7a566eb85fb5e825145464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 19:20:11 +0300 Subject: [PATCH] add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them) --- modules/sd_samplers.py | 3 +- modules/sd_samplers_cfg_denoiser.py | 50 +++------ modules/sd_samplers_common.py | 140 ++++++++++++++++++++++- modules/sd_samplers_kdiffusion.py | 154 +++----------------------- modules/sd_samplers_timesteps.py | 147 ++++++++++++++++++++++++ modules/sd_samplers_timesteps_impl.py | 135 ++++++++++++++++++++++ 6 files changed, 456 insertions(+), 173 deletions(-) create mode 100644 modules/sd_samplers_timesteps.py create mode 100644 modules/sd_samplers_timesteps_impl.py diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index bea2684c4..fe206894e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_compvis.samplers_data_compvis, + *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 33a49783e..166a00c7c 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model): + def __init__(self, model, sampler): super().__init__() self.inner_model = model self.mask = None @@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module): self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False + self.sampler = sampler def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module): return denoised + def get_pred_x0(self, x_in, x_out, sigma): + return x_out + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + if self.mask is not None: + x = self.init_latent * self.mask + self.nmask * x + batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module): devices.test_for_nans(x_out, "unet") - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - if is_edit_model: denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) elif skip_uncond: @@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised + self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) + + if opts.live_preview_content == "Prompt": + preview = self.sampler.last_latent + elif opts.live_preview_content == "Negative prompt": + preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma) + else: + preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma) + + sd_samplers_common.store_latent(preview) after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) cfg_after_cfg_callback(after_cfg_callback_params) @@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module): self.step += 1 return denoised - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) - diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 39586b40e..adda963bc 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,9 +1,11 @@ -from collections import namedtuple +import inspect +from collections import namedtuple, deque import numpy as np import torch from PIL import Image from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state +import k_diffusion.sampling SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -127,3 +129,139 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() + + +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. + self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + return devices.randn_like(x) + + +class Sampler: + def __init__(self, funcname): + self.funcname = funcname + self.func = funcname + self.extra_params = [] + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + + self.eta_option_field = 'eta_ancestral' + self.eta_infotext_field = 'Eta' + + self.conditioning_key = shared.sd_model.model.conditioning_key + + self.model_wrap = None + self.model_wrap_cfg = None + + def callback_state(self, d): + step = d['i'] + + if self.stop_at is not None and step > self.stop_at: + raise InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' + ) + return self.last_latent + except InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p) -> dict: + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params[self.eta_infotext_field] = self.eta + + extra_params_kwargs['eta'] = self.eta + + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + + return extra_params_kwargs + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 9c9b46d1e..3a2e01b77 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -4,8 +4,7 @@ import inspect import k_diffusion.sampling from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser -from modules.processing import StableDiffusionProcessing -from modules.shared import opts, state +from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ @@ -54,133 +53,17 @@ k_diffusion_scheduler = { } -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) - - -class KDiffusionSampler: +class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) + super().__init__(funcname) + self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.config = None # set by the function calling the constructor - self.last_latent = None - self.s_min_uncond = None + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - # NOTE: These are also defined in the StableDiffusionProcessing class. - # They should have been here to begin with but we're going to - # leave that class __init__ signature alone. - self.s_churn = 0.0 - self.s_tmin = 0.0 - self.s_tmax = float('inf') - self.s_noise = 1.0 - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except RecursionError: - print( - 'Encountered RecursionError during sampling, returning last latent. ' - 'rho >5 with a polyexponential scheduler may cause this error. ' - 'You should try to use a smaller rho value instead.' - ) - return self.last_latent - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p: StableDiffusionProcessing): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) - self.eta = p.eta if p.eta is not None else opts.eta_ancestral - self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: - p.extra_generation_params["Eta"] = self.eta - - extra_params_kwargs['eta'] = self.eta - - if len(self.extra_params) > 0: - s_churn = getattr(opts, 's_churn', p.s_churn) - s_tmin = getattr(opts, 's_tmin', p.s_tmin) - s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf - s_noise = getattr(opts, 's_noise', p.s_noise) - - if s_churn != self.s_churn: - extra_params_kwargs['s_churn'] = s_churn - p.s_churn = s_churn - p.extra_generation_params['Sigma churn'] = s_churn - if s_tmin != self.s_tmin: - extra_params_kwargs['s_tmin'] = s_tmin - p.s_tmin = s_tmin - p.extra_generation_params['Sigma tmin'] = s_tmin - if s_tmax != self.s_tmax: - extra_params_kwargs['s_tmax'] = s_tmax - p.s_tmax = s_tmax - p.extra_generation_params['Sigma tmax'] = s_tmax - if s_noise != self.s_noise: - extra_params_kwargs['s_noise'] = s_noise - p.s_noise = s_noise - p.extra_generation_params['Sigma noise'] = s_noise - - return extra_params_kwargs + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -232,22 +115,12 @@ class KDiffusionSampler: return sigmas - def create_noise_sampler(self, x, sigmas, p): - """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" - if shared.opts.no_dpmpp_sde_batch_determinism: - return None - - from k_diffusion.sampling import BrownianTreeNoiseSampler - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] - return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) sigmas = self.get_sigmas(p, steps) - sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] extra_params_kwargs = self.initialize(p) @@ -296,12 +169,14 @@ class KDiffusionSampler: extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters + if 'n' in parameters: + extra_params_kwargs['n'] = steps + if 'sigma_min' in parameters: extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in parameters: - extra_params_kwargs['n'] = steps - else: + + if 'sigmas' in parameters: extra_params_kwargs['sigmas'] = sigmas if self.config.options.get('brownian_noise', False): @@ -322,3 +197,4 @@ class KDiffusionSampler: return samples + diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py new file mode 100644 index 000000000..8560d009e --- /dev/null +++ b/modules/sd_samplers_timesteps.py @@ -0,0 +1,147 @@ +import torch +import inspect +from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl +from modules.sd_samplers_cfg_denoiser import CFGDenoiser + +from modules.shared import opts +import modules.shared as shared + +samplers_timesteps = [ + ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}), + ('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}), + ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}), +] + + +samplers_data_timesteps = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_timesteps +] + + +class CompVisTimestepsDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def forward(self, input, timesteps, **kwargs): + return self.inner_model.apply_model(input, timesteps, **kwargs) + + +class CompVisTimestepsVDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def predict_eps_from_z_and_v(self, x_t, t, v): + return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + + def forward(self, input, timesteps, **kwargs): + model_output = self.inner_model.apply_model(input, timesteps, **kwargs) + e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output) + return e_t + + +class CFGDenoiserTimesteps(CFGDenoiser): + + def __init__(self, model, sampler): + super().__init__(model, sampler) + + self.alphas = model.inner_model.alphas_cumprod + + def get_pred_x0(self, x_in, x_out, sigma): + ts = int(sigma.item()) + + s_in = x_in.new_ones([x_in.shape[0]]) + a_t = self.alphas[ts].item() * s_in + sqrt_one_minus_at = (1 - a_t).sqrt() + + pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt() + + return pred_x0 + + +class CompVisSampler(sd_samplers_common.Sampler): + def __init__(self, funcname, sd_model): + super().__init__(funcname) + + self.eta_option_field = 'eta_ddim' + self.eta_infotext_field = 'Eta DDIM' + + denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(sd_model) + self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + + def get_timesteps(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999) + + return timesteps + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + timesteps = self.get_timesteps(p, steps) + timesteps_sched = timesteps[:t_enc] + + alphas_cumprod = shared.sd_model.alphas_cumprod + sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]]) + sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]]) + + xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps_sched + if 'is_img2img' in parameters: + extra_params_kwargs['is_img2img'] = True + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + timesteps = self.get_timesteps(p, steps) + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps + + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py new file mode 100644 index 000000000..48d7e6491 --- /dev/null +++ b/modules/sd_samplers_timesteps_impl.py @@ -0,0 +1,135 @@ +import torch +import tqdm +import k_diffusion.sampling +import numpy as np + +from modules import shared +from modules.models.diffusion.uni_pc import uni_pc + + +@torch.no_grad() +def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + + e_t = model(x, timesteps[index].item() * s_in, **extra_args) + + a_t = alphas[index].item() * s_in + a_prev = alphas_prev[index].item() * s_in + sigma_t = sigmas[index].item() * s_in + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * k_diffusion.sampling.torch.randn_like(x) + x = a_prev.sqrt() * pred_x0 + dir_xt + noise + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +@torch.no_grad() +def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + old_eps = [] + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = alphas[index].item() * s_in + a_prev = alphas_prev[index].item() * s_in + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1. - a_prev).sqrt() * e_t + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + return x_prev, pred_x0 + + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + ts = timesteps[index].item() * s_in + t_next = timesteps[max(index - 1, 0)].item() * s_in + + e_t = model(x, ts, **extra_args) + + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = model(x_prev, t_next, **extra_args) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + else: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + + x = x_prev + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +class UniPCCFG(uni_pc.UniPC): + def __init__(self, cfg_model, extra_args, callback, *args, **kwargs): + super().__init__(None, *args, **kwargs) + + def after_update(x, model_x): + callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x}) + self.index += 1 + + self.cfg_model = cfg_model + self.extra_args = extra_args + self.callback = callback + self.index = 0 + self.after_update = after_update + + def get_model_input_time(self, t_continuous): + return (t_continuous - 1. / self.noise_schedule.total_N) * 1000. + + def model(self, x, t): + t_input = self.get_model_input_time(t) + + res = self.cfg_model(x, t_input, **self.extra_args) + + return res + + +def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + + ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means + unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant) + x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) + + return x