feature: beta scheduler
This commit is contained in:
parent
b2453d280a
commit
a5f66b5003
|
@ -2,6 +2,7 @@ import dataclasses
|
||||||
import torch
|
import torch
|
||||||
import k_diffusion
|
import k_diffusion
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy import stats
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
@ -115,6 +116,17 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
||||||
return torch.FloatTensor(sigs).to(device)
|
return torch.FloatTensor(sigs).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
|
||||||
|
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
|
||||||
|
alpha = 0.6
|
||||||
|
beta = 0.6
|
||||||
|
timesteps = 1 - np.linspace(0, 1, n)
|
||||||
|
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
|
||||||
|
sigmas = [sigma_min + ((x)*(sigma_max-sigma_min)) for x in timesteps] + [0.0]
|
||||||
|
sigmas = torch.FloatTensor(sigmas).to(device)
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
schedulers = [
|
schedulers = [
|
||||||
Scheduler('automatic', 'Automatic', None),
|
Scheduler('automatic', 'Automatic', None),
|
||||||
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
||||||
|
@ -127,6 +139,7 @@ schedulers = [
|
||||||
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
|
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
|
||||||
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
|
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
|
||||||
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
|
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
|
||||||
|
Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
||||||
|
|
Loading…
Reference in New Issue