diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 95a354dac..0c94d100d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -120,6 +120,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler): if scheduler.need_inner_model: sigmas_kwargs['inner_model'] = self.model_wrap + if scheduler.label == 'Beta': + p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha + p.extra_generation_params["Beta schedule beta"] = opts.beta_dist_beta + sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu) if discard_next_to_last_sigma: diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py index 84b0abb6a..f4d16e309 100644 --- a/modules/sd_schedulers.py +++ b/modules/sd_schedulers.py @@ -2,6 +2,7 @@ import dataclasses import torch import k_diffusion import numpy as np +from scipy import stats from modules import shared @@ -115,6 +116,17 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device): return torch.FloatTensor(sigs).to(device) +def beta_scheduler(n, sigma_min, sigma_max, inner_model, device): + # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """ + alpha = shared.opts.beta_dist_alpha + beta = shared.opts.beta_dist_beta + timesteps = 1 - np.linspace(0, 1, n) + timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps] + sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps] + sigmas += [0.0] + return torch.FloatTensor(sigmas).to(device) + + schedulers = [ Scheduler('automatic', 'Automatic', None), Scheduler('uniform', 'Uniform', uniform, need_inner_model=True), @@ -127,6 +139,7 @@ schedulers = [ Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True), Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True), Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True), + Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True), ] schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}} diff --git a/modules/shared_options.py b/modules/shared_options.py index a482c7c6d..9f4520274 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -405,6 +405,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"), 'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"), + 'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'), + 'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'), })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 566493266..6a42a04d9 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -259,6 +259,8 @@ axis_options = [ AxisOption("Schedule min sigma", float, apply_override("sigma_min")), AxisOption("Schedule max sigma", float, apply_override("sigma_max")), AxisOption("Schedule rho", float, apply_override("rho")), + AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")), + AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")), AxisOption("Eta", float, apply_field("eta")), AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')), AxisOption("Denoising", float, apply_field("denoising_strength")),