diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 228de4944..64e14e0c2 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,7 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback @@ -115,7 +115,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): if scheduler.need_inner_model: sigmas_kwargs['inner_model'] = self.model_wrap - sigmas = scheduler.function(n=steps, **sigmas_kwargs) + sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu) if discard_next_to_last_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py index 9916cf05a..0165e6a02 100644 --- a/modules/sd_schedulers.py +++ b/modules/sd_schedulers.py @@ -1,19 +1,19 @@ import dataclasses - import torch - import k_diffusion - import numpy as np from modules import shared + def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / sigma + k_diffusion.sampling.to_d = to_d + @dataclasses.dataclass class Scheduler: name: str @@ -25,11 +25,11 @@ class Scheduler: aliases: list = None -def uniform(n, sigma_min, sigma_max, inner_model): - return inner_model.get_sigmas(n) +def uniform(n, sigma_min, sigma_max, inner_model, device): + return inner_model.get_sigmas(n).to(device) -def sgm_uniform(n, sigma_min, sigma_max, inner_model): +def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): start = inner_model.sigma_to_t(torch.tensor(sigma_max)) end = inner_model.sigma_to_t(torch.tensor(sigma_min)) sigs = [ @@ -37,9 +37,10 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model): for ts in torch.linspace(start, end, n + 1)[:-1] ] sigs += [0.0] - return torch.FloatTensor(sigs) + return torch.FloatTensor(sigs).to(device) -def get_align_your_steps_sigmas(n, sigma_min, sigma_max): + +def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device): # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html def loglinear_interp(t_steps, num_steps): """ @@ -65,12 +66,13 @@ def get_align_your_steps_sigmas(n, sigma_min, sigma_max): else: sigmas.append(0.0) - return torch.FloatTensor(sigmas) + return torch.FloatTensor(sigmas).to(device) -def kl_optimal(n, sigma_min, sigma_max): - alpha_min = torch.arctan(torch.tensor(sigma_min)) - alpha_max = torch.arctan(torch.tensor(sigma_max)) - step_indices = torch.arange(n + 1) + +def kl_optimal(n, sigma_min, sigma_max, device): + alpha_min = torch.arctan(torch.tensor(sigma_min, device=device)) + alpha_max = torch.arctan(torch.tensor(sigma_max, device=device)) + step_indices = torch.arange(n + 1, device=device) sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max) return sigmas