diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 9a29f1ae4..28c2136fe 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,49 +1,11 @@ -from collections import deque -import torch -import inspect -import k_diffusion.sampling -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis - -from modules.shared import opts, state -import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image - -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), -] - -samplers_data_k_diffusion = [ - sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) -] - all_samplers = [ - *samplers_data_k_diffusion, - sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + *sd_samplers_kdiffusion.samplers_data_k_diffusion, + *sd_samplers_compvis.samplers_data_compvis, ] all_samplers_map = {x.name: x for x in all_samplers} @@ -69,8 +31,8 @@ def create_sampler(name, model): def set_samplers(): global samplers, samplers_for_img2img - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) + hidden = set(shared.opts.hide_samplers) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] @@ -83,257 +45,3 @@ def set_samplers(): set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5b06e341f..3c03d442e 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,4 +1,4 @@ -from collections import namedtuple, deque +from collections import namedtuple import numpy as np import torch from PIL import Image @@ -64,6 +64,7 @@ class InterruptedException(BaseException): # MPS fix for randn in torchsde +# XXX move this to separate file for MPS def torchsde_randn(size, dtype, device, seed): if device.type == 'mps': generator = torch.Generator(devices.cpu).manual_seed(int(seed)) diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 3d35ff72e..885411931 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,4 +1,6 @@ import math +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms import numpy as np import torch @@ -7,6 +9,12 @@ from modules.shared import state from modules import sd_samplers_common, prompt_parser, shared +samplers_data_compvis = [ + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), +] + + class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 9a29f1ae4..adb6883e6 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,18 +2,12 @@ from collections import deque import torch import inspect import k_diffusion.sampling -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -# imports for functions that previously were here and are used by other modules -from modules.sd_samplers_common import samples_to_image_grid, sample_to_image - - samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), ('Euler', 'sample_euler', ['k_euler'], {}), @@ -40,50 +34,6 @@ samplers_data_k_diffusion = [ if hasattr(k_diffusion.sampling, funcname) ] -all_samplers = [ - *samplers_data_k_diffusion, - sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], @@ -92,6 +42,13 @@ sampler_extra_params = { class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + def __init__(self, model): super().__init__() self.inner_model = model