fix bug
This commit is contained in:
parent
433cb3f801
commit
135acd83af
|
@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
|
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_to_config(
|
self.register_to_config(
|
||||||
snr=snr,
|
snr=snr,
|
||||||
sigma_min=sigma_min,
|
sigma_min=sigma_min,
|
||||||
sigma_max=sigma_max,
|
sigma_max=sigma_max,
|
||||||
N=N,
|
|
||||||
sampling_eps=sampling_eps,
|
sampling_eps=sampling_eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -54,7 +53,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
def step_pred(self, result, x, t):
|
def step_pred(self, result, x, t):
|
||||||
t = t * torch.ones(x.shape[0], device=x.device)
|
t = t * torch.ones(x.shape[0], device=x.device)
|
||||||
timestep = (t * (2 - 1)).long()
|
timestep = (t * (len(self.timesteps) - 1)).long()
|
||||||
|
|
||||||
sigma = self.discrete_sigmas.to(t.device)[timestep]
|
sigma = self.discrete_sigmas.to(t.device)[timestep]
|
||||||
adjacent_sigma = torch.where(
|
adjacent_sigma = torch.where(
|
||||||
|
|
Loading…
Reference in New Issue