This commit is contained in:
Patrick von Platen 2022-06-26 00:56:18 +00:00
parent 433cb3f801
commit 135acd83af
1 changed files with 2 additions and 3 deletions

View File

@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
super().__init__() super().__init__()
self.register_to_config( self.register_to_config(
snr=snr, snr=snr,
sigma_min=sigma_min, sigma_min=sigma_min,
sigma_max=sigma_max, sigma_max=sigma_max,
N=N,
sampling_eps=sampling_eps, sampling_eps=sampling_eps,
) )
@ -54,7 +53,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def step_pred(self, result, x, t): def step_pred(self, result, x, t):
t = t * torch.ones(x.shape[0], device=x.device) t = t * torch.ones(x.shape[0], device=x.device)
timestep = (t * (2 - 1)).long() timestep = (t * (len(self.timesteps) - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep] sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where( adjacent_sigma = torch.where(